diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..83cfd8dbb643612f79f25d84b65ac7e4b3c4fb7f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.wav filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..2f78cf5b66514f2506d9af5f3dadf3dee7aa6d9f --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +*.pyc + diff --git a/Examples/Beethoven.wav b/Examples/Beethoven.wav new file mode 100755 index 0000000000000000000000000000000000000000..5b2d61c0fbebcbb1ef2e040cc975cf45af248337 --- /dev/null +++ b/Examples/Beethoven.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:30a6a087a9e0eb87422aa3b48ad966eabb1dfe105d73a25d356b71d3aee31493 +size 4828972 diff --git a/Examples/Beethoven_arcade.wav b/Examples/Beethoven_arcade.wav new file mode 100644 index 0000000000000000000000000000000000000000..8a6bf3f4e642681ce95e7e8c604862d22f61046a --- /dev/null +++ b/Examples/Beethoven_arcade.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ccd929b93c15706f2102a27973d490a84ce0eb97faba6a92ece0c6d81ed2c26e +size 1794746 diff --git a/Examples/Beethoven_piano.wav b/Examples/Beethoven_piano.wav new file mode 100644 index 0000000000000000000000000000000000000000..0dedc751b223a8209344c92129a7ca45760f08f1 --- /dev/null +++ b/Examples/Beethoven_piano.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5787c31b0b3c78dec33d651d437364785713042e7cfce2290cf4baf01f65ac6f +size 1794746 diff --git a/Examples/Cat.wav b/Examples/Cat.wav new file mode 100644 index 0000000000000000000000000000000000000000..bb849ac6f86dadcb51af1992415b4428901e7b77 --- /dev/null +++ b/Examples/Cat.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:27b43763a8d9ac90dc78285ed9817b16524f24b4f4d1aa399616f1a04d4a9fd9 +size 1920508 diff --git a/Examples/Cat_dog.wav b/Examples/Cat_dog.wav new file mode 100644 index 0000000000000000000000000000000000000000..3e400307510e72d7edf889c6fc17737351c64387 --- /dev/null +++ b/Examples/Cat_dog.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:90a97dc229eeccef307dd40db97fb09cc439ce0b45a320fd84b2ea6b03d0deb2 +size 327822 diff --git a/Examples/ModalJazz.wav b/Examples/ModalJazz.wav new file mode 100644 index 0000000000000000000000000000000000000000..a29553d312f8bd8071512c8c85ff5ead79047c61 --- /dev/null +++ b/Examples/ModalJazz.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:846a77046d21ebc3996841404eede9d56797c82b3414025e1ccafe586eaf2959 +size 9153322 diff --git a/Examples/ModalJazz_banjo.wav b/Examples/ModalJazz_banjo.wav new file mode 100644 index 0000000000000000000000000000000000000000..99a51a2b14e03075aecf18a7ea5f9f0acafde473 --- /dev/null +++ b/Examples/ModalJazz_banjo.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:122e0078c0bf2fc96425071706fe0e8674c93cc1d2787fd02c0e2c0f12de5cc5 +size 6802106 diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..1532c5b27dbedfcaf7eeaf9e8f76a340bf61e43a --- /dev/null +++ b/app.py @@ -0,0 +1,271 @@ +import gradio as gr +import random +import torch +import torchaudio +from torch import inference_mode +from tempfile import NamedTemporaryFile +import numpy as np +from models import load_model +import utils +from inversion_utils import inversion_forward_process, inversion_reverse_process + + +def randomize_seed_fn(seed, randomize_seed): + if randomize_seed: + seed = random.randint(0, np.iinfo(np.int32).max) + torch.manual_seed(seed) + return seed + + +def invert(x0, prompt_src, num_diffusion_steps, cfg_scale_src): # , ldm_stable): + ldm_stable.model.scheduler.set_timesteps(num_diffusion_steps, device=device) + + with inference_mode(): + w0 = ldm_stable.vae_encode(x0) + + # find Zs and wts - forward process + _, zs, wts = inversion_forward_process(ldm_stable, w0, etas=1, + prompts=[prompt_src], + cfg_scales=[cfg_scale_src], + prog_bar=True, + num_inference_steps=num_diffusion_steps, + numerical_fix=True) + return zs, wts + + +def sample(zs, wts, steps, prompt_tar, tstart, cfg_scale_tar): # , ldm_stable): + # reverse process (via Zs and wT) + tstart = torch.tensor(tstart, dtype=torch.int) + skip = steps - tstart + w0, _ = inversion_reverse_process(ldm_stable, xT=wts, skips=steps - skip, + etas=1., prompts=[prompt_tar], + neg_prompts=[""], cfg_scales=[cfg_scale_tar], + prog_bar=True, + zs=zs[:int(steps - skip)]) + + # vae decode image + with inference_mode(): + x0_dec = ldm_stable.vae_decode(w0) + if x0_dec.dim() < 4: + x0_dec = x0_dec[None, :, :, :] + + with torch.no_grad(): + audio = ldm_stable.decode_to_mel(x0_dec) + + f = NamedTemporaryFile("wb", suffix=".wav", delete=False) + torchaudio.save(f.name, audio, sample_rate=16000) + + return f.name + + +def edit(input_audio, + model_id: str, + do_inversion: bool, + wts: gr.State, zs: gr.State, saved_inv_model: str, + source_prompt="", + target_prompt="", + steps=200, + cfg_scale_src=3.5, + cfg_scale_tar=12, + t_start=90, + randomize_seed=True): + + global ldm_stable, current_loaded_model + print(f'current loaded model: {ldm_stable.model_id}') + if model_id != current_loaded_model: + print(f'Changing model to {model_id}...') + current_loaded_model = model_id + ldm_stable = None + ldm_stable = load_model(model_id, device, steps) + + # If the inversion was done for a different model, we need to re-run the inversion + if not do_inversion and (saved_inv_model is None or saved_inv_model != model_id): + do_inversion = True + + x0 = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device) + + if do_inversion or randomize_seed: # always re-run inversion + zs_tensor, wts_tensor = invert(x0=x0, prompt_src=source_prompt, + num_diffusion_steps=steps, + cfg_scale_src=cfg_scale_src) + wts = gr.State(value=wts_tensor) + zs = gr.State(value=zs_tensor) + saved_inv_model = model_id + do_inversion = False + + output = sample(zs.value, wts.value, steps, prompt_tar=target_prompt, tstart=t_start, + cfg_scale_tar=cfg_scale_tar) + + return output, wts, zs, saved_inv_model, do_inversion + + +current_loaded_model = "cvssp/audioldm2-music" +# current_loaded_model = "cvssp/audioldm2-music" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +ldm_stable = load_model(current_loaded_model, device, 200) # deafult model + + +def get_example(): + case = [ + ['Examples/Beethoven.wav', + '', + 'A recording of an arcade game soundtrack.', + 90, + 'cvssp/audioldm2-music', + '27s', + 'Examples/Beethoven_arcade.wav', + ], + ['Examples/Beethoven.wav', + 'A high quality recording of wind instruments and strings playing.', + 'A high quality recording of a piano playing.', + 90, + 'cvssp/audioldm2-music', + '27s', + 'Examples/Beethoven_piano.wav', + ], + ['Examples/ModalJazz.wav', + 'Trumpets playing alongside a piano, bass and drums in an upbeat old-timey cool jazz song.', + 'A banjo playing alongside a piano, bass and drums in an upbeat old-timey cool country song.', + 90, + 'cvssp/audioldm2-music', + '106s', + 'Examples/ModalJazz_banjo.wav',], + ['Examples/Cat.wav', + '', + 'A dog barking.', + 150, + 'cvssp/audioldm2-large', + '10s', + 'Examples/Cat_dog.wav',] + ] + return case + + +intro = """ +

Zero-Shot Text-Based Audio Editing Using DDPM Inversion

+

+ [Paper] |  + [Project page] |  + [Code] +

+

+Demo for the text-based editing method introduced in: + Zero-Shot Unsupervised and Text-Based Audio Editing Using DDPM Inversion +

+

+Instructions:
+Provide an input audio and a target prompt to edit the audio.
+Tstart is used to control the tradeoff between fidelity to the original signal and text-adhearance. +Lower value -> favor fidelity. Higher value -> apply a stronger edit.
+Make sure that you use an AudioLDM2 version that is suitable for your input audio. +For example, use the music version for music and the large version for general audio. +

+

+You can additionally provide a source prompt to guide even further the editing process. +

+

Longer input will take more time.

+

+For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. + +Duplicate Space +

+ +""" + +with gr.Blocks(css='style.css') as demo: + def reset_do_inversion(): + do_inversion = gr.State(value=True) + return do_inversion + + gr.HTML(intro) + wts = gr.State() + zs = gr.State() + saved_inv_model = gr.State() + # current_loaded_model = gr.State(value="cvssp/audioldm2-music") + # ldm_stable = load_model("cvssp/audioldm2-music", device, 200) + # ldm_stable = gr.State(value=ldm_stable) + do_inversion = gr.State(value=True) # To save some runtime when editing the same thing over and over + + with gr.Row(): + with gr.Column(): + src_prompt = gr.Textbox(label="OPTIONAL: Source Prompt", lines=2, interactive=True, + placeholder="Optional: Describe the original audio input",) + input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Input Audio", + interactive=True, scale=1) + + with gr.Column(): + tar_prompt = gr.Textbox(label="Target Prompt", placeholder="Describe your desired edited output", + lines=2, interactive=True) + output_audio = gr.Audio(label="Edited Audio", interactive=False, scale=1) + + with gr.Row(): + with gr.Column(): + submit = gr.Button("Edit") + + with gr.Row(): + t_start = gr.Slider(minimum=30, maximum=160, value=110, step=1, label="T-start", interactive=True, scale=3, + info="Higher T-start -> stronger edit. Lower T-start -> more similar to original audio.") + model_id = gr.Dropdown(label="AudioLDM2 Version", choices=["cvssp/audioldm2", + "cvssp/audioldm2-large", + "cvssp/audioldm2-music"], + info="Choose a checkpoint suitable for your intended audio and edit.", + value="cvssp/audioldm2-music", interactive=True, type="value", scale=2) + with gr.Accordion("More Options", open=False): + + with gr.Row(): + cfg_scale_src = gr.Number(value=3, minimum=0.5, maximum=25, precision=None, + label="Source Guidance Scale", interactive=True, scale=1) + cfg_scale_tar = gr.Number(value=12, minimum=0.5, maximum=25, precision=None, + label="Target Guidance Scale", interactive=True, scale=1) + steps = gr.Number(value=200, precision=0, minimum=20, maximum=1000, + label="Num Diffusion Steps", interactive=True, scale=1) + with gr.Row(): + seed = gr.Number(value=0, precision=0, label="Seed", interactive=True) + randomize_seed = gr.Checkbox(label='Randomize seed', value=False) + length = gr.Number(label="Length", interactive=False, visible=False) + + def change_tstart_range(steps): + t_start.maximum = int(160/200 * steps) + t_start.minimum = int(30/200 * steps) + if t_start.value > t_start.maximum: + t_start.value = t_start.maximum + if t_start.value < t_start.minimum: + t_start.value = t_start.minimum + return t_start + + submit.click( + fn=randomize_seed_fn, + inputs=[seed, randomize_seed], + outputs=[seed], queue=False).then( + fn=edit, + inputs=[input_audio, + model_id, + do_inversion, + # current_loaded_model, ldm_stable, + wts, zs, saved_inv_model, + src_prompt, + tar_prompt, + steps, + cfg_scale_src, + cfg_scale_tar, + t_start, + randomize_seed + ], + outputs=[output_audio, wts, zs, saved_inv_model, do_inversion] # , current_loaded_model, ldm_stable], + ) + + # If sources changed we have to rerun inversion + input_audio.change(fn=reset_do_inversion, outputs=[do_inversion]) + src_prompt.change(fn=reset_do_inversion, outputs=[do_inversion]) + model_id.change(fn=reset_do_inversion, outputs=[do_inversion]) + steps.change(fn=change_tstart_range, inputs=[steps], outputs=[t_start]) + + gr.Examples( + label="Examples", + examples=get_example(), + inputs=[input_audio, src_prompt, tar_prompt, t_start, model_id, length, output_audio], + outputs=[output_audio] + ) + + demo.queue() + demo.launch() diff --git a/audioldm/__init__.py b/audioldm/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..2bbf85f01ccc72b6f18e7405d940adf07a26b500 --- /dev/null +++ b/audioldm/__init__.py @@ -0,0 +1,8 @@ +from .ldm import LatentDiffusion +from .utils import seed_everything, save_wave, get_time, get_duration +from .pipeline import * + + + + + diff --git a/audioldm/__main__.py b/audioldm/__main__.py new file mode 100755 index 0000000000000000000000000000000000000000..fd597a810642bf518f7bbbc113e134526d918ba7 --- /dev/null +++ b/audioldm/__main__.py @@ -0,0 +1,183 @@ +#!/usr/bin/python3 +import os +from audioldm import text_to_audio, style_transfer, build_model, save_wave, get_time, round_up_duration, get_duration +import argparse + +CACHE_DIR = os.getenv( + "AUDIOLDM_CACHE_DIR", + os.path.join(os.path.expanduser("~"), ".cache/audioldm")) + +parser = argparse.ArgumentParser() + +parser.add_argument( + "--mode", + type=str, + required=False, + default="generation", + help="generation: text-to-audio generation; transfer: style transfer", + choices=["generation", "transfer"] +) + +parser.add_argument( + "-t", + "--text", + type=str, + required=False, + default="", + help="Text prompt to the model for audio generation", +) + +parser.add_argument( + "-f", + "--file_path", + type=str, + required=False, + default=None, + help="(--mode transfer): Original audio file for style transfer; Or (--mode generation): the guidance audio file for generating simialr audio", +) + +parser.add_argument( + "--transfer_strength", + type=float, + required=False, + default=0.5, + help="A value between 0 and 1. 0 means original audio without transfer, 1 means completely transfer to the audio indicated by text", +) + +parser.add_argument( + "-s", + "--save_path", + type=str, + required=False, + help="The path to save model output", + default="./output", +) + +parser.add_argument( + "--model_name", + type=str, + required=False, + help="The checkpoint you gonna use", + default="audioldm-m-full", + choices=["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2","audioldm-m-text-ft", "audioldm-s-text-ft", "audioldm-m-full"] +) + +parser.add_argument( + "-ckpt", + "--ckpt_path", + type=str, + required=False, + help="The path to the pretrained .ckpt model", + default=None, +) + +parser.add_argument( + "-b", + "--batchsize", + type=int, + required=False, + default=1, + help="Generate how many samples at the same time", +) + +parser.add_argument( + "--ddim_steps", + type=int, + required=False, + default=200, + help="The sampling step for DDIM", +) + +parser.add_argument( + "-gs", + "--guidance_scale", + type=float, + required=False, + default=2.5, + help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)", +) + +parser.add_argument( + "-dur", + "--duration", + type=float, + required=False, + default=10.0, + help="The duration of the samples", +) + +parser.add_argument( + "-n", + "--n_candidate_gen_per_text", + type=int, + required=False, + default=3, + help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation", +) + +parser.add_argument( + "--seed", + type=int, + required=False, + default=42, + help="Change this value (any integer number) will lead to a different generation result.", +) + +args = parser.parse_args() + +if(args.ckpt_path is not None): + print("Warning: ckpt_path has no effect after version 0.0.20.") + +assert args.duration % 2.5 == 0, "Duration must be a multiple of 2.5" + +mode = args.mode +if(mode == "generation" and args.file_path is not None): + mode = "generation_audio_to_audio" + if(len(args.text) > 0): + print("Warning: You have specified the --file_path. --text will be ignored") + args.text = "" + +save_path = os.path.join(args.save_path, mode) + +if(args.file_path is not None): + save_path = os.path.join(save_path, os.path.basename(args.file_path.split(".")[0])) + +text = args.text +random_seed = args.seed +duration = args.duration +guidance_scale = args.guidance_scale +n_candidate_gen_per_text = args.n_candidate_gen_per_text + +os.makedirs(save_path, exist_ok=True) +audioldm = build_model(model_name=args.model_name) + +if(args.mode == "generation"): + waveform = text_to_audio( + audioldm, + text, + args.file_path, + random_seed, + duration=duration, + guidance_scale=guidance_scale, + ddim_steps=args.ddim_steps, + n_candidate_gen_per_text=n_candidate_gen_per_text, + batchsize=args.batchsize, + ) + +elif(args.mode == "transfer"): + assert args.file_path is not None + assert os.path.exists(args.file_path), "The original audio file \'%s\' for style transfer does not exist." % args.file_path + waveform = style_transfer( + audioldm, + text, + args.file_path, + args.transfer_strength, + random_seed, + duration=duration, + guidance_scale=guidance_scale, + ddim_steps=args.ddim_steps, + batchsize=args.batchsize, + ) + waveform = waveform[:,None,:] + +save_wave(waveform, save_path, name="%s_%s" % (get_time(), text)) diff --git a/audioldm/audio/__init__.py b/audioldm/audio/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..56902e96f041bc4ba6bfadd7a7742023b9560233 --- /dev/null +++ b/audioldm/audio/__init__.py @@ -0,0 +1,2 @@ +from .tools import wav_to_fbank, read_wav_file +from .stft import TacotronSTFT diff --git a/audioldm/audio/audio_processing.py b/audioldm/audio/audio_processing.py new file mode 100755 index 0000000000000000000000000000000000000000..77a4057aa82f226f68474f4c2a19eba84510d663 --- /dev/null +++ b/audioldm/audio/audio_processing.py @@ -0,0 +1,100 @@ +import torch +import numpy as np +import librosa.util as librosa_util +from scipy.signal import get_window + + +def window_sumsquare( + window, + n_frames, + hop_length, + win_length, + n_fft, + dtype=np.float32, + norm=None, +): + """ + # from librosa 0.6 + Compute the sum-square envelope of a window function at a given hop length. + + This is used to estimate modulation effects induced by windowing + observations in short-time fourier transforms. + + Parameters + ---------- + window : string, tuple, number, callable, or list-like + Window specification, as in `get_window` + + n_frames : int > 0 + The number of analysis frames + + hop_length : int > 0 + The number of samples to advance between frames + + win_length : [optional] + The length of the window function. By default, this matches `n_fft`. + + n_fft : int > 0 + The length of each analysis frame. + + dtype : np.dtype + The data type of the output + + Returns + ------- + wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` + The sum-squared envelope of the window function + """ + if win_length is None: + win_length = n_fft + + n = n_fft + hop_length * (n_frames - 1) + x = np.zeros(n, dtype=dtype) + + # Compute the squared window at the desired length + win_sq = get_window(window, win_length, fftbins=True) + win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 + win_sq = librosa_util.pad_center(win_sq, n_fft) + + # Fill the envelope + for i in range(n_frames): + sample = i * hop_length + x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] + return x + + +def griffin_lim(magnitudes, stft_fn, n_iters=30): + """ + PARAMS + ------ + magnitudes: spectrogram magnitudes + stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods + """ + + angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) + angles = angles.astype(np.float32) + angles = torch.autograd.Variable(torch.from_numpy(angles)) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + + for i in range(n_iters): + _, angles = stft_fn.transform(signal) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + return signal + + +def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return normalize_fun(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C diff --git a/audioldm/audio/stft.py b/audioldm/audio/stft.py new file mode 100755 index 0000000000000000000000000000000000000000..6af485d45d5d371e7c9bc72531497fa7f7a716c1 --- /dev/null +++ b/audioldm/audio/stft.py @@ -0,0 +1,180 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy.signal import get_window +from librosa.util import pad_center, tiny +from librosa.filters import mel as librosa_mel_fn + +from audioldm.audio.audio_processing import ( + dynamic_range_compression, + dynamic_range_decompression, + window_sumsquare, +) + + +class STFT(torch.nn.Module): + """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" + + def __init__(self, filter_length, hop_length, win_length, window="hann"): + super(STFT, self).__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.forward_transform = None + scale = self.filter_length / self.hop_length + fourier_basis = np.fft.fft(np.eye(self.filter_length)) + + cutoff = int((self.filter_length / 2 + 1)) + fourier_basis = np.vstack( + [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] + ) + + forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) + inverse_basis = torch.FloatTensor( + np.linalg.pinv(scale * fourier_basis).T[:, None, :] + ) + + if window is not None: + assert filter_length >= win_length + # get window and zero center pad it to filter_length + fft_window = get_window(window, win_length, fftbins=True) + fft_window = pad_center(fft_window, size=filter_length) + fft_window = torch.from_numpy(fft_window).float() + + # window the bases + forward_basis *= fft_window + inverse_basis *= fft_window + + self.register_buffer("forward_basis", forward_basis.float()) + self.register_buffer("inverse_basis", inverse_basis.float()) + + def transform(self, input_data): + num_batches = input_data.size(0) + num_samples = input_data.size(1) + + self.num_samples = num_samples + + # similar to librosa, reflect-pad the input + input_data = input_data.view(num_batches, 1, num_samples) + input_data = F.pad( + input_data.unsqueeze(1), + (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), + mode="reflect", + ) + input_data = input_data.squeeze(1) + + forward_transform = F.conv1d( + input_data, + torch.autograd.Variable(self.forward_basis, requires_grad=False), + stride=self.hop_length, + padding=0, + ).cpu() + + cutoff = int((self.filter_length / 2) + 1) + real_part = forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + + magnitude = torch.sqrt(real_part**2 + imag_part**2) + phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) + + return magnitude, phase + + def inverse(self, magnitude, phase): + recombine_magnitude_phase = torch.cat( + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 + ) + + inverse_transform = F.conv_transpose1d( + recombine_magnitude_phase, + torch.autograd.Variable(self.inverse_basis, requires_grad=False), + stride=self.hop_length, + padding=0, + ) + + if self.window is not None: + window_sum = window_sumsquare( + self.window, + magnitude.size(-1), + hop_length=self.hop_length, + win_length=self.win_length, + n_fft=self.filter_length, + dtype=np.float32, + ) + # remove modulation effects + approx_nonzero_indices = torch.from_numpy( + np.where(window_sum > tiny(window_sum))[0] + ) + window_sum = torch.autograd.Variable( + torch.from_numpy(window_sum), requires_grad=False + ) + window_sum = window_sum + inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ + approx_nonzero_indices + ] + + # scale by hop ratio + inverse_transform *= float(self.filter_length) / self.hop_length + + inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] + inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] + + return inverse_transform + + def forward(self, input_data): + self.magnitude, self.phase = self.transform(input_data) + reconstruction = self.inverse(self.magnitude, self.phase) + return reconstruction + + +class TacotronSTFT(torch.nn.Module): + def __init__( + self, + filter_length, + hop_length, + win_length, + n_mel_channels, + sampling_rate, + mel_fmin, + mel_fmax, + ): + super(TacotronSTFT, self).__init__() + self.n_mel_channels = n_mel_channels + self.sampling_rate = sampling_rate + self.stft_fn = STFT(filter_length, hop_length, win_length) + mel_basis = librosa_mel_fn( + sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax + ) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis) + + def spectral_normalize(self, magnitudes, normalize_fun): + output = dynamic_range_compression(magnitudes, normalize_fun) + return output + + def spectral_de_normalize(self, magnitudes): + output = dynamic_range_decompression(magnitudes) + return output + + def mel_spectrogram(self, y, normalize_fun=torch.log): + """Computes mel-spectrograms from a batch of waves + PARAMS + ------ + y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] + + RETURNS + ------- + mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) + """ + assert torch.min(y.data) >= -1, torch.min(y.data) + assert torch.max(y.data) <= 1, torch.max(y.data) + + magnitudes, phases = self.stft_fn.transform(y) + magnitudes = magnitudes.data + mel_output = torch.matmul(self.mel_basis, magnitudes) + mel_output = self.spectral_normalize(mel_output, normalize_fun) + energy = torch.norm(magnitudes, dim=1) + + log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun) + + return mel_output, log_magnitudes, energy diff --git a/audioldm/audio/tools.py b/audioldm/audio/tools.py new file mode 100755 index 0000000000000000000000000000000000000000..d641a982664b6673822c8528a1929c593f011b11 --- /dev/null +++ b/audioldm/audio/tools.py @@ -0,0 +1,85 @@ +import torch +import numpy as np +import torchaudio + + +def get_mel_from_wav(audio, _stft): + audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) + audio = torch.autograd.Variable(audio, requires_grad=False) + melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio) + melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) + log_magnitudes_stft = ( + torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32) + ) + energy = torch.squeeze(energy, 0).numpy().astype(np.float32) + return melspec, log_magnitudes_stft, energy + + +def _pad_spec(fbank, target_length=1024): + n_frames = fbank.shape[0] + p = target_length - n_frames + # cut and pad + if p > 0: + m = torch.nn.ZeroPad2d((0, 0, 0, p)) + fbank = m(fbank) + elif p < 0: + fbank = fbank[0:target_length, :] + + if fbank.size(-1) % 2 != 0: + fbank = fbank[..., :-1] + + return fbank + + +def pad_wav(waveform, segment_length): + waveform_length = waveform.shape[-1] + assert waveform_length > 100, "Waveform is too short, %s" % waveform_length + if segment_length is None or waveform_length == segment_length: + return waveform + elif waveform_length > segment_length: + return waveform[:segment_length] + elif waveform_length < segment_length: + temp_wav = np.zeros((1, segment_length)) + temp_wav[:, :waveform_length] = waveform + return temp_wav + +def normalize_wav(waveform): + waveform = waveform - np.mean(waveform) + waveform = waveform / (np.max(np.abs(waveform)) + 1e-8) + return waveform * 0.5 + + +def read_wav_file(filename, segment_length): + # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower + waveform, sr = torchaudio.load(filename) # Faster!!! + waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000) + waveform = waveform.numpy()[0, ...] + waveform = normalize_wav(waveform) + waveform = waveform[None, ...] + waveform = pad_wav(waveform, segment_length) + + waveform = waveform / np.max(np.abs(waveform)) + waveform = 0.5 * waveform + + return waveform + + +def wav_to_fbank(filename, target_length=1024, fn_STFT=None): + assert fn_STFT is not None + + # mixup + waveform = read_wav_file(filename, target_length * 160) # hop size is 160 + + waveform = waveform[0, ...] + waveform = torch.FloatTensor(waveform) + + fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT) + + fbank = torch.FloatTensor(fbank.T) + log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T) + + fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec( + log_magnitudes_stft, target_length + ) + + return fbank, log_magnitudes_stft, waveform diff --git a/audioldm/clap/__init__.py b/audioldm/clap/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm/clap/encoders.py b/audioldm/clap/encoders.py new file mode 100755 index 0000000000000000000000000000000000000000..d82d57fc6ea7b6c74aa38db2af47abccbd4e9a00 --- /dev/null +++ b/audioldm/clap/encoders.py @@ -0,0 +1,171 @@ +import torch +import torch.nn as nn +from audioldm.clap.open_clip import create_model +from audioldm.clap.training.data import get_audio_features +import torchaudio +from transformers import RobertaTokenizer +import torch.nn.functional as F + + +class CLAPAudioEmbeddingClassifierFreev2(nn.Module): + def __init__( + self, + pretrained_path="", + key="class", + sampling_rate=16000, + embed_mode="audio", + amodel = "HTSAT-tiny", + unconditional_prob=0.1, + random_mute=False, + max_random_mute_portion=0.5, + training_mode=True, + ): + super().__init__() + + self.key = key + self.device = "cpu" + self.precision = "fp32" + self.amodel = amodel # or 'PANN-14' + self.tmodel = "roberta" # the best text encoder in our training + self.enable_fusion = False # False if you do not want to use the fusion model + self.fusion_type = "aff_2d" + self.pretrained = pretrained_path + self.embed_mode = embed_mode + self.embed_mode_orig = embed_mode + self.sampling_rate = sampling_rate + self.unconditional_prob = unconditional_prob + self.random_mute = random_mute + self.tokenize = RobertaTokenizer.from_pretrained("roberta-base") + self.max_random_mute_portion = max_random_mute_portion + self.training_mode = training_mode + self.model, self.model_cfg = create_model( + self.amodel, + self.tmodel, + self.pretrained, + precision=self.precision, + device=self.device, + enable_fusion=self.enable_fusion, + fusion_type=self.fusion_type, + ) + for p in self.model.parameters(): + p.requires_grad = False + + self.model.eval() + + def get_unconditional_condition(self, batchsize): + self.unconditional_token = self.model.get_text_embedding( + self.tokenizer(["", ""]) + )[0:1] + return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0) + + def batch_to_list(self, batch): + ret = [] + for i in range(batch.size(0)): + ret.append(batch[i]) + return ret + + def make_decision(self, probability): + if float(torch.rand(1)) < probability: + return True + else: + return False + + def random_uniform(self, start, end): + val = torch.rand(1).item() + return start + (end - start) * val + + def _random_mute(self, waveform): + # waveform: [bs, t-steps] + t_steps = waveform.size(-1) + for i in range(waveform.size(0)): + mute_size = int( + self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion)) + ) + mute_start = int(self.random_uniform(0, t_steps - mute_size)) + waveform[i, mute_start : mute_start + mute_size] = 0 + return waveform + + def cos_similarity(self, waveform, text): + # waveform: [bs, t_steps] + with torch.no_grad(): + self.embed_mode = "audio" + print(text) + audio_emb = self(waveform.cuda()) + self.embed_mode = "text" + text_emb = self(text) + similarity = F.cosine_similarity(audio_emb, text_emb, dim=2) + return similarity.squeeze() + + def forward(self, batch, key=None): + # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0 + # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0 + if self.model.training == True and not self.training_mode: + print( + "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters." + ) + self.model, self.model_cfg = create_model( + self.amodel, + self.tmodel, + self.pretrained, + precision=self.precision, + device="cuda", + enable_fusion=self.enable_fusion, + fusion_type=self.fusion_type, + ) + for p in self.model.parameters(): + p.requires_grad = False + self.model.eval() + + # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode + if self.embed_mode == "audio": + with torch.no_grad(): + audio_dict_list = [] + assert ( + self.sampling_rate == 16000 + ), "We only support 16000 sampling rate" + if self.random_mute: + batch = self._random_mute(batch) + # batch: [bs, 1, t-samples] + batch = torchaudio.functional.resample( + batch, orig_freq=self.sampling_rate, new_freq=48000 + ) + for waveform in self.batch_to_list(batch): + audio_dict = {} + audio_dict = get_audio_features( + audio_dict, + waveform, + 480000, + data_truncating="fusion", + data_filling="repeatpad", + audio_cfg=self.model_cfg["audio_cfg"], + ) + audio_dict_list.append(audio_dict) + # [bs, 512] + embed = self.model.get_audio_embedding(audio_dict_list) + elif self.embed_mode == "text": + with torch.no_grad(): + # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode + text_data = self.tokenizer(batch) + embed = self.model.get_text_embedding(text_data) + + embed = embed.unsqueeze(1) + self.unconditional_token = self.model.get_text_embedding( + self.tokenizer(["", ""]) + )[0:1] + + for i in range(embed.size(0)): + if self.make_decision(self.unconditional_prob): + embed[i] = self.unconditional_token + + # [bs, 1, 512] + return embed.detach() + + def tokenizer(self, text): + result = self.tokenize( + text, + padding="max_length", + truncation=True, + max_length=512, + return_tensors="pt", + ) + return {k: v.squeeze(0) for k, v in result.items()} diff --git a/audioldm/clap/open_clip/__init__.py b/audioldm/clap/open_clip/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e9f728f2f273be5d5fdbec6c6cc41d737176a8c0 --- /dev/null +++ b/audioldm/clap/open_clip/__init__.py @@ -0,0 +1,25 @@ +from .factory import ( + list_models, + create_model, + create_model_and_transforms, + add_model_config, +) +from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics +from .model import ( + CLAP, + CLAPTextCfg, + CLAPVisionCfg, + CLAPAudioCfp, + convert_weights_to_fp16, + trace_model, +) +from .openai import load_openai_model, list_openai_models +from .pretrained import ( + list_pretrained, + list_pretrained_tag_models, + list_pretrained_model_tags, + get_pretrained_url, + download_pretrained, +) +from .tokenizer import SimpleTokenizer, tokenize +from .transform import image_transform diff --git a/audioldm/clap/open_clip/bert.py b/audioldm/clap/open_clip/bert.py new file mode 100755 index 0000000000000000000000000000000000000000..a83d96d2a77ed05198efc05837522bc88d2499cc --- /dev/null +++ b/audioldm/clap/open_clip/bert.py @@ -0,0 +1,40 @@ +from transformers import BertTokenizer, BertModel + +tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") +model = BertModel.from_pretrained("bert-base-uncased") +text = "Replace me by any text you'd like." + + +def bert_embeddings(text): + # text = "Replace me by any text you'd like." + encoded_input = tokenizer(text, return_tensors="pt") + output = model(**encoded_input) + return output + + +from transformers import RobertaTokenizer, RobertaModel + +tokenizer = RobertaTokenizer.from_pretrained("roberta-base") +model = RobertaModel.from_pretrained("roberta-base") +text = "Replace me by any text you'd like." + + +def Roberta_embeddings(text): + # text = "Replace me by any text you'd like." + encoded_input = tokenizer(text, return_tensors="pt") + output = model(**encoded_input) + return output + + +from transformers import BartTokenizer, BartModel + +tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") +model = BartModel.from_pretrained("facebook/bart-base") +text = "Replace me by any text you'd like." + + +def bart_embeddings(text): + # text = "Replace me by any text you'd like." + encoded_input = tokenizer(text, return_tensors="pt") + output = model(**encoded_input) + return output diff --git a/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz b/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz new file mode 100755 index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113 --- /dev/null +++ b/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a +size 1356917 diff --git a/audioldm/clap/open_clip/factory.py b/audioldm/clap/open_clip/factory.py new file mode 100755 index 0000000000000000000000000000000000000000..6e943a50abc767dfd8e04f1c27ae4830332717cd --- /dev/null +++ b/audioldm/clap/open_clip/factory.py @@ -0,0 +1,279 @@ +import json +import logging +import os +import pathlib +import re +from copy import deepcopy +from pathlib import Path + +import torch + +from .model import CLAP, convert_weights_to_fp16 +from .openai import load_openai_model +from .pretrained import get_pretrained_url, download_pretrained +from .transform import image_transform + +_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs +CACHE_DIR = os.getenv("AUDIOLDM_CACHE_DIR", "~/.cache/audioldm") + + + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = (".json",) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f"*{ext}")) + + for cf in config_files: + if os.path.basename(cf)[0] == ".": + continue # Ignore hidden files + + with open(cf, "r") as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = { + k: v + for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])) + } + + +_rescan_model_configs() # initial populate of model config registry + + +def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True): + checkpoint = torch.load(checkpoint_path, map_location=map_location) + if isinstance(checkpoint, dict) and "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + if skip_params: + if next(iter(state_dict.items()))[0].startswith("module"): + state_dict = {k[7:]: v for k, v in state_dict.items()} + # for k in state_dict: + # if k.startswith('transformer'): + # v = state_dict.pop(k) + # state_dict['text_branch.' + k[12:]] = v + return state_dict + + +def create_model( + amodel_name: str, + tmodel_name: str, + pretrained: str = "", + precision: str = "fp32", + device: torch.device = torch.device("cpu"), + jit: bool = False, + force_quick_gelu: bool = False, + openai_model_cache_dir: str = os.path.expanduser(f"{CACHE_DIR}/clip"), + skip_params=True, + pretrained_audio: str = "", + pretrained_text: str = "", + enable_fusion: bool = False, + fusion_type: str = "None" + # pretrained_image: bool = False, +): + amodel_name = amodel_name.replace( + "/", "-" + ) # for callers using old naming with / in ViT names + pretrained_orig = pretrained + pretrained = pretrained.lower() + if pretrained == "openai": + if amodel_name in _MODEL_CONFIGS: + logging.info(f"Loading {amodel_name} model config.") + model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) + else: + logging.error( + f"Model config for {amodel_name} not found; available models {list_models()}." + ) + raise RuntimeError(f"Model config for {amodel_name} not found.") + + logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.") + # Hard Code in model name + model_cfg["text_cfg"]["model_type"] = tmodel_name + model = load_openai_model( + "ViT-B-16", + model_cfg, + device=device, + jit=jit, + cache_dir=openai_model_cache_dir, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 + if precision == "amp" or precision == "fp32": + model = model.float() + else: + if amodel_name in _MODEL_CONFIGS: + logging.info(f"Loading {amodel_name} model config.") + model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) + else: + logging.error( + f"Model config for {amodel_name} not found; available models {list_models()}." + ) + raise RuntimeError(f"Model config for {amodel_name} not found.") + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + # if pretrained_image: + # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}): + # # pretrained weight loading for timm models set via vision_cfg + # model_cfg['vision_cfg']['timm_model_pretrained'] = True + # else: + # assert False, 'pretrained image towers currently only supported for timm models' + model_cfg["text_cfg"]["model_type"] = tmodel_name + model_cfg["enable_fusion"] = enable_fusion + model_cfg["fusion_type"] = fusion_type + model = CLAP(**model_cfg) + + if pretrained: + checkpoint_path = "" + url = get_pretrained_url(amodel_name, pretrained) + if url: + checkpoint_path = download_pretrained(url, root=openai_model_cache_dir) + elif os.path.exists(pretrained_orig): + checkpoint_path = pretrained_orig + if checkpoint_path: + logging.info( + f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})." + ) + ckpt = load_state_dict(checkpoint_path, skip_params=True) + model.load_state_dict(ckpt) + param_names = [n for n, p in model.named_parameters()] + # for n in param_names: + # print(n, "\t", "Loaded" if n in ckpt else "Unloaded") + else: + logging.warning( + f"Pretrained weights ({pretrained}) not found for model {amodel_name}." + ) + raise RuntimeError( + f"Pretrained weights ({pretrained}) not found for model {amodel_name}." + ) + + if pretrained_audio: + if amodel_name.startswith("PANN"): + if "Cnn14_mAP" in pretrained_audio: # official checkpoint + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["model"] + keys = list(audio_ckpt.keys()) + for key in keys: + if ( + "spectrogram_extractor" not in key + and "logmel_extractor" not in key + ): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." + key] = v + elif os.path.basename(pretrained_audio).startswith( + "PANN" + ): # checkpoint trained via HTSAT codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["state_dict"] + keys = list(audio_ckpt.keys()) + for key in keys: + if key.startswith("sed_model"): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." + key[10:]] = v + elif os.path.basename(pretrained_audio).startswith( + "finetuned" + ): # checkpoint trained via linear probe codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + else: + raise ValueError("Unknown audio checkpoint") + elif amodel_name.startswith("HTSAT"): + if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["state_dict"] + keys = list(audio_ckpt.keys()) + for key in keys: + if key.startswith("sed_model") and ( + "spectrogram_extractor" not in key + and "logmel_extractor" not in key + ): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." + key[10:]] = v + elif os.path.basename(pretrained_audio).startswith( + "HTSAT" + ): # checkpoint trained via HTSAT codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["state_dict"] + keys = list(audio_ckpt.keys()) + for key in keys: + if key.startswith("sed_model"): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." + key[10:]] = v + elif os.path.basename(pretrained_audio).startswith( + "finetuned" + ): # checkpoint trained via linear probe codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + else: + raise ValueError("Unknown audio checkpoint") + else: + raise f"this audio encoder pretrained checkpoint is not support" + + model.load_state_dict(audio_ckpt, strict=False) + logging.info( + f"Loading pretrained {amodel_name} weights ({pretrained_audio})." + ) + param_names = [n for n, p in model.named_parameters()] + for n in param_names: + print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded") + + model.to(device=device) + if precision == "fp16": + assert device.type != "cpu" + convert_weights_to_fp16(model) + + if jit: + model = torch.jit.script(model) + + return model, model_cfg + + +def create_model_and_transforms( + model_name: str, + pretrained: str = "", + precision: str = "fp32", + device: torch.device = torch.device("cpu"), + jit: bool = False, + force_quick_gelu: bool = False, + # pretrained_image: bool = False, +): + model = create_model( + model_name, + pretrained, + precision, + device, + jit, + force_quick_gelu=force_quick_gelu, + # pretrained_image=pretrained_image + ) + preprocess_train = image_transform(model.visual.image_size, is_train=True) + preprocess_val = image_transform(model.visual.image_size, is_train=False) + return model, preprocess_train, preprocess_val + + +def list_models(): + """enumerate available model architectures based on config files""" + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """add model config path or file and update registry""" + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() diff --git a/audioldm/clap/open_clip/feature_fusion.py b/audioldm/clap/open_clip/feature_fusion.py new file mode 100755 index 0000000000000000000000000000000000000000..dbe4e170e05894c12ebdc36ba1dc1de65e441b89 --- /dev/null +++ b/audioldm/clap/open_clip/feature_fusion.py @@ -0,0 +1,192 @@ +""" +Feature Fusion for Varible-Length Data Processing +AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py +According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021 +""" + +import torch +import torch.nn as nn + + +class DAF(nn.Module): + """ + 直接相加 DirectAddFuse + """ + + def __init__(self): + super(DAF, self).__init__() + + def forward(self, x, residual): + return x + residual + + +class iAFF(nn.Module): + """ + 多特征融合 iAFF + """ + + def __init__(self, channels=64, r=4, type="2D"): + super(iAFF, self).__init__() + inter_channels = int(channels // r) + + if type == "1D": + # 本地注意力 + self.local_att = nn.Sequential( + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + + # 全局注意力 + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + + # 第二次本地注意力 + self.local_att2 = nn.Sequential( + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + # 第二次全局注意力 + self.global_att2 = nn.Sequential( + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + elif type == "2D": + # 本地注意力 + self.local_att = nn.Sequential( + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + + # 全局注意力 + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + + # 第二次本地注意力 + self.local_att2 = nn.Sequential( + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + # 第二次全局注意力 + self.global_att2 = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + else: + raise f"the type is not supported" + + self.sigmoid = nn.Sigmoid() + + def forward(self, x, residual): + flag = False + xa = x + residual + if xa.size(0) == 1: + xa = torch.cat([xa, xa], dim=0) + flag = True + xl = self.local_att(xa) + xg = self.global_att(xa) + xlg = xl + xg + wei = self.sigmoid(xlg) + xi = x * wei + residual * (1 - wei) + + xl2 = self.local_att2(xi) + xg2 = self.global_att(xi) + xlg2 = xl2 + xg2 + wei2 = self.sigmoid(xlg2) + xo = x * wei2 + residual * (1 - wei2) + if flag: + xo = xo[0].unsqueeze(0) + return xo + + +class AFF(nn.Module): + """ + 多特征融合 AFF + """ + + def __init__(self, channels=64, r=4, type="2D"): + super(AFF, self).__init__() + inter_channels = int(channels // r) + + if type == "1D": + self.local_att = nn.Sequential( + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool1d(1), + nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm1d(channels), + ) + elif type == "2D": + self.local_att = nn.Sequential( + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + self.global_att = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(inter_channels), + nn.ReLU(inplace=True), + nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), + nn.BatchNorm2d(channels), + ) + else: + raise f"the type is not supported." + + self.sigmoid = nn.Sigmoid() + + def forward(self, x, residual): + flag = False + xa = x + residual + if xa.size(0) == 1: + xa = torch.cat([xa, xa], dim=0) + flag = True + xl = self.local_att(xa) + xg = self.global_att(xa) + xlg = xl + xg + wei = self.sigmoid(xlg) + xo = 2 * x * wei + 2 * residual * (1 - wei) + if flag: + xo = xo[0].unsqueeze(0) + return xo diff --git a/audioldm/clap/open_clip/htsat.py b/audioldm/clap/open_clip/htsat.py new file mode 100755 index 0000000000000000000000000000000000000000..3b856c6a43df162116a941f1b5c76e93713b276a --- /dev/null +++ b/audioldm/clap/open_clip/htsat.py @@ -0,0 +1,1308 @@ +# Ke Chen +# knutchen@ucsd.edu +# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION +# Some layers designed on the model +# below codes are based and referred from https://github.com/microsoft/Swin-Transformer +# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf + +import torch +import torch.nn as nn +import torch.nn.functional as F +from itertools import repeat +import collections.abc +import math +import warnings + +from torch.nn.init import _calculate_fan_in_and_fan_out +import torch.utils.checkpoint as checkpoint + +import random + +from torchlibrosa.stft import Spectrogram, LogmelFilterBank +from torchlibrosa.augmentation import SpecAugmentation + +from itertools import repeat +from .utils import do_mixup, interpolate + +from .feature_fusion import iAFF, AFF, DAF + +# from PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + patch_stride=16, + enable_fusion=False, + fusion_type="None", + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patch_stride = to_2tuple(patch_stride) + self.img_size = img_size + self.patch_size = patch_size + self.patch_stride = patch_stride + self.grid_size = ( + img_size[0] // patch_stride[0], + img_size[1] // patch_stride[1], + ) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + padding = ( + (patch_size[0] - patch_stride[0]) // 2, + (patch_size[1] - patch_stride[1]) // 2, + ) + + if (self.enable_fusion) and (self.fusion_type == "channel_map"): + self.proj = nn.Conv2d( + in_chans * 4, + embed_dim, + kernel_size=patch_size, + stride=patch_stride, + padding=padding, + ) + else: + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_stride, + padding=padding, + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + self.mel_conv2d = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=(patch_size[0], patch_size[1] * 3), + stride=(patch_stride[0], patch_stride[1] * 3), + padding=padding, + ) + if self.fusion_type == "daf_2d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_2d": + self.fusion_model = AFF(channels=embed_dim, type="2D") + elif self.fusion_type == "iaff_2d": + self.fusion_model = iAFF(channels=embed_dim, type="2D") + + def forward(self, x, longer_idx=None): + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + global_x = x[:, 0:1, :, :] + + # global processing + B, C, H, W = global_x.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + global_x = self.proj(global_x) + TW = global_x.size(-1) + if len(longer_idx) > 0: + # local processing + local_x = x[longer_idx, 1:, :, :].contiguous() + B, C, H, W = local_x.shape + local_x = local_x.view(B * C, 1, H, W) + local_x = self.mel_conv2d(local_x) + local_x = local_x.view( + B, C, local_x.size(1), local_x.size(2), local_x.size(3) + ) + local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3) + TB, TC, TH, _ = local_x.size() + if local_x.size(-1) < TW: + local_x = torch.cat( + [ + local_x, + torch.zeros( + (TB, TC, TH, TW - local_x.size(-1)), + device=global_x.device, + ), + ], + dim=-1, + ) + else: + local_x = local_x[:, :, :, :TW] + + global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x) + x = global_x + else: + B, C, H, W = x.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view( + B, H // window_size, W // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r"""Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B_, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( + 1 + ).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + def extra_repr(self): + return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}" + + +# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model +class SwinTransformerBlock(nn.Module): + r"""Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + norm_before_mlp="ln", + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.norm_before_mlp = norm_before_mlp + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert ( + 0 <= self.shift_size < self.window_size + ), "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + if self.norm_before_mlp == "ln": + self.norm2 = nn.LayerNorm(dim) + elif self.norm_before_mlp == "bn": + self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose( + 1, 2 + ) + else: + raise NotImplementedError + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill( + attn_mask != 0, float(-100.0) + ).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + # pdb.set_trace() + H, W = self.input_resolution + # print("H: ", H) + # print("W: ", W) + # pdb.set_trace() + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) + ) + else: + shifted_x = x + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows, attn = self.attn( + x_windows, mask=self.attn_mask + ) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) + ) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x, attn + + def extra_repr(self): + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + ) + + +class PatchMerging(nn.Module): + r"""Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self): + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + norm_before_mlp="ln", + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) + else drop_path, + norm_layer=norm_layer, + norm_before_mlp=norm_before_mlp, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, norm_layer=norm_layer + ) + else: + self.downsample = None + + def forward(self, x): + attns = [] + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x, attn = blk(x) + if not self.training: + attns.append(attn.unsqueeze(0)) + if self.downsample is not None: + x = self.downsample(x) + if not self.training: + attn = torch.cat(attns, dim=0) + attn = torch.mean(attn, dim=0) + return x, attn + + def extra_repr(self): + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +# The Core of HTSAT +class HTSAT_Swin_Transformer(nn.Module): + r"""HTSAT based on the Swin Transformer + Args: + spec_size (int | tuple(int)): Input Spectrogram size. Default 256 + patch_size (int | tuple(int)): Patch size. Default: 4 + path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4 + in_chans (int): Number of input image channels. Default: 1 (mono) + num_classes (int): Number of classes for classification head. Default: 527 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 8 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + config (module): The configuration Module from config.py + """ + + def __init__( + self, + spec_size=256, + patch_size=4, + patch_stride=(4, 4), + in_chans=1, + num_classes=527, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[4, 8, 16, 32], + window_size=8, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + norm_before_mlp="ln", + config=None, + enable_fusion=False, + fusion_type="None", + **kwargs, + ): + super(HTSAT_Swin_Transformer, self).__init__() + + self.config = config + self.spec_size = spec_size + self.patch_stride = patch_stride + self.patch_size = patch_size + self.window_size = window_size + self.embed_dim = embed_dim + self.depths = depths + self.ape = ape + self.in_chans = in_chans + self.num_classes = num_classes + self.num_heads = num_heads + self.num_layers = len(self.depths) + self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1)) + + self.drop_rate = drop_rate + self.attn_drop_rate = attn_drop_rate + self.drop_path_rate = drop_path_rate + + self.qkv_bias = qkv_bias + self.qk_scale = None + + self.patch_norm = patch_norm + self.norm_layer = norm_layer if self.patch_norm else None + self.norm_before_mlp = norm_before_mlp + self.mlp_ratio = mlp_ratio + + self.use_checkpoint = use_checkpoint + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # process mel-spec ; used only once + self.freq_ratio = self.spec_size // self.config.mel_bins + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + self.interpolate_ratio = 32 # Downsampled ratio + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=config.window_size, + hop_length=config.hop_size, + win_length=config.window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=config.sample_rate, + n_fft=config.window_size, + n_mels=config.mel_bins, + fmin=config.fmin, + fmax=config.fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) # 2 2 + self.bn0 = nn.BatchNorm2d(self.config.mel_bins) + + # split spctrogram into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=self.spec_size, + patch_size=self.patch_size, + in_chans=self.in_chans, + embed_dim=self.embed_dim, + norm_layer=self.norm_layer, + patch_stride=patch_stride, + enable_fusion=self.enable_fusion, + fusion_type=self.fusion_type, + ) + + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.grid_size + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, self.embed_dim) + ) + trunc_normal_(self.absolute_pos_embed, std=0.02) + + self.pos_drop = nn.Dropout(p=self.drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(self.embed_dim * 2**i_layer), + input_resolution=( + patches_resolution[0] // (2**i_layer), + patches_resolution[1] // (2**i_layer), + ), + depth=self.depths[i_layer], + num_heads=self.num_heads[i_layer], + window_size=self.window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + qk_scale=self.qk_scale, + drop=self.drop_rate, + attn_drop=self.attn_drop_rate, + drop_path=dpr[ + sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1]) + ], + norm_layer=self.norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + norm_before_mlp=self.norm_before_mlp, + ) + self.layers.append(layer) + + self.norm = self.norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.maxpool = nn.AdaptiveMaxPool1d(1) + + SF = ( + self.spec_size + // (2 ** (len(self.depths) - 1)) + // self.patch_stride[0] + // self.freq_ratio + ) + self.tscam_conv = nn.Conv2d( + in_channels=self.num_features, + out_channels=self.num_classes, + kernel_size=(SF, 3), + padding=(0, 1), + ) + self.head = nn.Linear(num_classes, num_classes) + + if (self.enable_fusion) and ( + self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"] + ): + self.mel_conv1d = nn.Sequential( + nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2), + nn.BatchNorm1d(64), + ) + if self.fusion_type == "daf_1d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_1d": + self.fusion_model = AFF(channels=64, type="1D") + elif self.fusion_type == "iaff_1d": + self.fusion_model = iAFF(channels=64, type="1D") + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {"absolute_pos_embed"} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {"relative_position_bias_table"} + + def forward_features(self, x, longer_idx=None): + # A deprecated optimization for using a hierarchical output from different blocks + + frames_num = x.shape[2] + x = self.patch_embed(x, longer_idx=longer_idx) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + for i, layer in enumerate(self.layers): + x, attn = layer(x) + # for x + x = self.norm(x) + B, N, C = x.shape + SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] + ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1] + x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST) + B, C, F, T = x.shape + # group 2D CNN + c_freq_bin = F // self.freq_ratio + x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T) + x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1) + # get latent_output + fine_grained_latent_output = torch.mean(x, dim=2) + fine_grained_latent_output = interpolate( + fine_grained_latent_output.permute(0, 2, 1).contiguous(), + 8 * self.patch_stride[1], + ) + + latent_output = self.avgpool(torch.flatten(x, 2)) + latent_output = torch.flatten(latent_output, 1) + + # display the attention map, if needed + + x = self.tscam_conv(x) + x = torch.flatten(x, 2) # B, C, T + + fpx = interpolate( + torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1] + ) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + + output_dict = { + "framewise_output": fpx, # already sigmoided + "clipwise_output": torch.sigmoid(x), + "fine_grained_embedding": fine_grained_latent_output, + "embedding": latent_output, + } + + return output_dict + + def crop_wav(self, x, crop_size, spe_pos=None): + time_steps = x.shape[2] + tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device) + for i in range(len(x)): + if spe_pos is None: + crop_pos = random.randint(0, time_steps - crop_size - 1) + else: + crop_pos = spe_pos + tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :] + return tx + + # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model + def reshape_wav2img(self, x): + B, C, T, F = x.shape + target_T = int(self.spec_size * self.freq_ratio) + target_F = self.spec_size // self.freq_ratio + assert ( + T <= target_T and F <= target_F + ), "the wav size should less than or equal to the swin input size" + # to avoid bicubic zero error + if T < target_T: + x = nn.functional.interpolate( + x, (target_T, x.shape[3]), mode="bicubic", align_corners=True + ) + if F < target_F: + x = nn.functional.interpolate( + x, (x.shape[2], target_F), mode="bicubic", align_corners=True + ) + x = x.permute(0, 1, 3, 2).contiguous() + x = x.reshape( + x.shape[0], + x.shape[1], + x.shape[2], + self.freq_ratio, + x.shape[3] // self.freq_ratio, + ) + # print(x.shape) + x = x.permute(0, 1, 3, 2, 4).contiguous() + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4]) + return x + + # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model + def repeat_wat2img(self, x, cur_pos): + B, C, T, F = x.shape + target_T = int(self.spec_size * self.freq_ratio) + target_F = self.spec_size // self.freq_ratio + assert ( + T <= target_T and F <= target_F + ), "the wav size should less than or equal to the swin input size" + # to avoid bicubic zero error + if T < target_T: + x = nn.functional.interpolate( + x, (target_T, x.shape[3]), mode="bicubic", align_corners=True + ) + if F < target_F: + x = nn.functional.interpolate( + x, (x.shape[2], target_F), mode="bicubic", align_corners=True + ) + x = x.permute(0, 1, 3, 2).contiguous() # B C F T + x = x[:, :, :, cur_pos : cur_pos + self.spec_size] + x = x.repeat(repeats=(1, 1, 4, 1)) + return x + + def forward( + self, x: torch.Tensor, mixup_lambda=None, infer_mode=False, device=None + ): # out_feat_keys: List[str] = None): + + if self.enable_fusion and x["longer"].sum() == 0: + # if no audio is longer than 10s, then randomly select one audio to be longer + x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True + + if not self.enable_fusion: + x = x["waveform"].to(device=device, non_blocking=True) + x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + if self.training: + x = self.spec_augmenter(x) + + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.reshape_wav2img(x) + output_dict = self.forward_features(x) + else: + longer_list = x["longer"].to(device=device, non_blocking=True) + x = x["mel_fusion"].to(device=device, non_blocking=True) + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + longer_list_idx = torch.where(longer_list)[0] + if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]: + new_x = x[:, 0:1, :, :].clone().contiguous() + if len(longer_list_idx) > 0: + # local processing + fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous() + FB, FC, FT, FF = fusion_x_local.size() + fusion_x_local = fusion_x_local.view(FB * FC, FT, FF) + fusion_x_local = torch.permute( + fusion_x_local, (0, 2, 1) + ).contiguous() + fusion_x_local = self.mel_conv1d(fusion_x_local) + fusion_x_local = fusion_x_local.view( + FB, FC, FF, fusion_x_local.size(-1) + ) + fusion_x_local = ( + torch.permute(fusion_x_local, (0, 2, 1, 3)) + .contiguous() + .flatten(2) + ) + if fusion_x_local.size(-1) < FT: + fusion_x_local = torch.cat( + [ + fusion_x_local, + torch.zeros( + (FB, FF, FT - fusion_x_local.size(-1)), + device=device, + ), + ], + dim=-1, + ) + else: + fusion_x_local = fusion_x_local[:, :, :FT] + # 1D fusion + new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous() + new_x[longer_list_idx] = self.fusion_model( + new_x[longer_list_idx], fusion_x_local + ) + x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :] + else: + x = new_x + + elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]: + x = x # no change + + if self.training: + x = self.spec_augmenter(x) + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.reshape_wav2img(x) + output_dict = self.forward_features(x, longer_idx=longer_list_idx) + + # if infer_mode: + # # in infer mode. we need to handle different length audio input + # frame_num = x.shape[2] + # target_T = int(self.spec_size * self.freq_ratio) + # repeat_ratio = math.floor(target_T / frame_num) + # x = x.repeat(repeats=(1,1,repeat_ratio,1)) + # x = self.reshape_wav2img(x) + # output_dict = self.forward_features(x) + # else: + # if x.shape[2] > self.freq_ratio * self.spec_size: + # if self.training: + # x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size) + # x = self.reshape_wav2img(x) + # output_dict = self.forward_features(x) + # else: + # # Change: Hard code here + # overlap_size = (x.shape[2] - 1) // 4 + # output_dicts = [] + # crop_size = (x.shape[2] - 1) // 2 + # for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size): + # tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos) + # tx = self.reshape_wav2img(tx) + # output_dicts.append(self.forward_features(tx)) + # clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device) + # framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device) + # for d in output_dicts: + # clipwise_output += d["clipwise_output"] + # framewise_output += d["framewise_output"] + # clipwise_output = clipwise_output / len(output_dicts) + # framewise_output = framewise_output / len(output_dicts) + # output_dict = { + # 'framewise_output': framewise_output, + # 'clipwise_output': clipwise_output + # } + # else: # this part is typically used, and most easy one + # x = self.reshape_wav2img(x) + # output_dict = self.forward_features(x) + # x = self.head(x) + + # We process the data in the dataloader part, in that here we only consider the input_T < fixed_T + + return output_dict + + +def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"): + try: + + assert audio_cfg.model_name in [ + "tiny", + "base", + "large", + ], "model name for HTS-AT is wrong!" + if audio_cfg.model_name == "tiny": + model = HTSAT_Swin_Transformer( + spec_size=256, + patch_size=4, + patch_stride=(4, 4), + num_classes=audio_cfg.class_num, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[4, 8, 16, 32], + window_size=8, + config=audio_cfg, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + elif audio_cfg.model_name == "base": + model = HTSAT_Swin_Transformer( + spec_size=256, + patch_size=4, + patch_stride=(4, 4), + num_classes=audio_cfg.class_num, + embed_dim=128, + depths=[2, 2, 12, 2], + num_heads=[4, 8, 16, 32], + window_size=8, + config=audio_cfg, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + elif audio_cfg.model_name == "large": + model = HTSAT_Swin_Transformer( + spec_size=256, + patch_size=4, + patch_stride=(4, 4), + num_classes=audio_cfg.class_num, + embed_dim=256, + depths=[2, 2, 12, 2], + num_heads=[4, 8, 16, 32], + window_size=8, + config=audio_cfg, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + + return model + except: + raise RuntimeError( + f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough." + ) diff --git a/audioldm/clap/open_clip/linear_probe.py b/audioldm/clap/open_clip/linear_probe.py new file mode 100755 index 0000000000000000000000000000000000000000..9d7e23b6b67a53e16d050d675a99d01d7d04d581 --- /dev/null +++ b/audioldm/clap/open_clip/linear_probe.py @@ -0,0 +1,66 @@ +import numpy as np +import torch.nn.functional as F +from torch import nn +from .model import MLPLayers + + +class LinearProbe(nn.Module): + def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None): + """ + Args: + model: nn.Module + mlp: bool, if True, then use the MLP layer as the linear probe module + freeze: bool, if Ture, then freeze all the CLAP model's layers when training the linear probe + in_ch: int, the output channel from CLAP model + out_ch: int, the output channel from linear probe (class_num) + act: torch.nn.functional, the activation function before the loss function + """ + super().__init__() + in_ch = 512 + self.clap_model = model + self.clap_model.text_branch = None # to save memory + self.freeze = freeze + if mlp: + self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch]) + else: + self.lp_layer = nn.Linear(in_ch, out_ch) + + if self.freeze: + for param in self.clap_model.parameters(): + param.requires_grad = False + + if act == "None": + self.act = None + elif act == "relu": + self.act = nn.ReLU() + elif act == "elu": + self.act = nn.ELU() + elif act == "prelu": + self.act = nn.PReLU(num_parameters=in_ch) + elif act == "softmax": + self.act = nn.Softmax(dim=-1) + elif act == "sigmoid": + self.act = nn.Sigmoid() + + def forward(self, x, mix_lambda=None, device=None): + """ + Args: + x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list + mix_lambda: torch.tensor [batch], the mixup lambda + Returns: + class_prob: torch.tensor [batch, class_num] + + """ + # batchnorm cancel grandient + if self.freeze: + self.clap_model.eval() + + x = self.clap_model.audio_projection( + self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)[ + "embedding" + ] + ) + out = self.lp_layer(x) + if self.act is not None: + out = self.act(out) + return out diff --git a/audioldm/clap/open_clip/loss.py b/audioldm/clap/open_clip/loss.py new file mode 100755 index 0000000000000000000000000000000000000000..cc66298a14997da4aa2efc71e37c0a6bcda53fd1 --- /dev/null +++ b/audioldm/clap/open_clip/loss.py @@ -0,0 +1,398 @@ +from multiprocessing.sharedctypes import Value +import torch +import torch.distributed.nn +from torch import distributed as dist, nn as nn +from torch.nn import functional as F +import numpy as np +from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +def gather_features( + audio_features, + text_features, + audio_features_mlp=None, + text_features_mlp=None, + local_loss=False, + gather_with_grad=False, + rank=0, + world_size=1, + use_horovod=False, + mlp_loss=False, +): + if use_horovod: + assert hvd is not None, "Please install horovod" + if gather_with_grad: + all_audio_features = hvd.allgather(audio_features) + all_text_features = hvd.allgather(text_features) + if mlp_loss: + all_audio_features_mlp = hvd.allgather(audio_features_mlp) + all_text_features_mlp = hvd.allgather(text_features_mlp) + else: + with torch.no_grad(): + all_audio_features = hvd.allgather(audio_features) + all_text_features = hvd.allgather(text_features) + if mlp_loss: + all_audio_features_mlp = hvd.allgather(audio_features_mlp) + all_text_features_mlp = hvd.allgather(text_features_mlp) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_audio_features = list( + all_audio_features.chunk(world_size, dim=0) + ) + gathered_text_features = list( + all_text_features.chunk(world_size, dim=0) + ) + gathered_audio_features[rank] = audio_features + gathered_text_features[rank] = text_features + all_audio_features = torch.cat(gathered_audio_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + if mlp_loss: + gathered_audio_features_mlp = list( + all_audio_features_mlp.chunk(world_size, dim=0) + ) + gathered_text_features_mlp = list( + all_text_features_mlp.chunk(world_size, dim=0) + ) + gathered_audio_features_mlp[rank] = audio_features_mlp + gathered_text_features_mlp[rank] = text_features_mlp + all_audio_features_mlp = torch.cat( + gathered_audio_features_mlp, dim=0 + ) + all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0) + else: + # We gather tensors from all gpus + if gather_with_grad: + all_audio_features = torch.cat( + torch.distributed.nn.all_gather(audio_features), dim=0 + ) + all_text_features = torch.cat( + torch.distributed.nn.all_gather(text_features), dim=0 + ) + if mlp_loss: + all_audio_features_mlp = torch.cat( + torch.distributed.nn.all_gather(audio_features_mlp), dim=0 + ) + all_text_features_mlp = torch.cat( + torch.distributed.nn.all_gather(text_features_mlp), dim=0 + ) + else: + gathered_audio_features = [ + torch.zeros_like(audio_features) for _ in range(world_size) + ] + gathered_text_features = [ + torch.zeros_like(text_features) for _ in range(world_size) + ] + dist.all_gather(gathered_audio_features, audio_features) + dist.all_gather(gathered_text_features, text_features) + if mlp_loss: + gathered_audio_features_mlp = [ + torch.zeros_like(audio_features_mlp) for _ in range(world_size) + ] + gathered_text_features_mlp = [ + torch.zeros_like(text_features_mlp) for _ in range(world_size) + ] + dist.all_gather(gathered_audio_features_mlp, audio_features_mlp) + dist.all_gather(gathered_text_features_mlp, text_features_mlp) + if not local_loss: + # ensure grads for local rank when all_* features don't have a gradient + gathered_audio_features[rank] = audio_features + gathered_text_features[rank] = text_features + if mlp_loss: + gathered_audio_features_mlp[rank] = audio_features_mlp + gathered_text_features_mlp[rank] = text_features_mlp + + all_audio_features = torch.cat(gathered_audio_features, dim=0) + all_text_features = torch.cat(gathered_text_features, dim=0) + if mlp_loss: + all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0) + all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0) + if mlp_loss: + return ( + all_audio_features, + all_text_features, + all_audio_features_mlp, + all_text_features_mlp, + ) + else: + return all_audio_features, all_text_features + + +class ClipLoss(nn.Module): + def __init__( + self, + local_loss=False, + gather_with_grad=False, + cache_labels=False, + rank=0, + world_size=1, + use_horovod=False, + mlp_loss=False, + weight_loss_kappa=0, + ): + super().__init__() + self.local_loss = local_loss + self.gather_with_grad = gather_with_grad + self.cache_labels = cache_labels + self.rank = rank + self.world_size = world_size + self.use_horovod = use_horovod + self.mlp_loss = mlp_loss + self.weighted_loss = bool(weight_loss_kappa != 0) + self.weight_loss_kappa = weight_loss_kappa + # cache state + self.prev_num_logits = 0 + self.labels = {} + + def forward( + self, + audio_features, + text_features, + logit_scale_a, + logit_scale_t=None, + audio_features_mlp=None, + text_features_mlp=None, + ): + device = audio_features.device + if self.mlp_loss: + if self.world_size > 1: + ( + all_audio_features, + all_text_features, + all_audio_features_mlp, + all_text_features_mlp, + ) = gather_features( + audio_features=audio_features, + text_features=text_features, + audio_features_mlp=audio_features_mlp, + text_features_mlp=text_features_mlp, + local_loss=self.local_loss, + gather_with_grad=self.gather_with_grad, + rank=self.rank, + world_size=self.world_size, + use_horovod=self.use_horovod, + mlp_loss=self.mlp_loss, + ) + if self.local_loss: + a_logits_per_audio = ( + logit_scale_a * audio_features @ all_text_features_mlp.T + ) + a_logits_per_text = ( + logit_scale_a * text_features_mlp @ all_audio_features.T + ) + t_logits_per_audio = ( + logit_scale_t * audio_features_mlp @ all_text_features.T + ) + t_logits_per_text = ( + logit_scale_t * text_features @ all_audio_features_mlp.T + ) + else: + a_logits_per_audio = ( + logit_scale_a * all_audio_features @ all_text_features_mlp.T + ) + a_logits_per_text = a_logits_per_audio.T + t_logits_per_audio = ( + logit_scale_t * all_audio_features_mlp @ all_text_features.T + ) + t_logits_per_text = t_logits_per_audio.T + else: + a_logits_per_audio = ( + logit_scale_a * audio_features @ text_features_mlp.T + ) + a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T + t_logits_per_audio = ( + logit_scale_t * audio_features_mlp @ text_features.T + ) + t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T + + # calculated ground-truth and cache if enabled + num_logits = a_logits_per_audio.shape[0] + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + + if not self.weighted_loss: + total_loss = ( + F.cross_entropy(a_logits_per_audio, labels) + + F.cross_entropy(a_logits_per_text, labels) + + F.cross_entropy(t_logits_per_audio, labels) + + F.cross_entropy(t_logits_per_text, labels) + ) / 4 + else: + audio_weight = (audio_features @ audio_features.T).detach() + audio_weight = ( + torch.exp( + torch.sum(audio_weight, axis=1) + / (self.weight_loss_kappa * len(audio_weight)) + ) + ).detach() + text_weight = (text_features @ text_features.T).detach() + text_weight = ( + torch.exp( + torch.sum(text_weight, axis=1) + / (self.weight_loss_kappa * len(text_features)) + ) + ).detach() + total_loss = ( + F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight) + + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight) + + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight) + + F.cross_entropy(t_logits_per_text, labels, weight=text_weight) + ) / 4 + else: + if self.world_size > 1: + all_audio_features, all_text_features = gather_features( + audio_features=audio_features, + text_features=text_features, + local_loss=self.local_loss, + gather_with_grad=self.gather_with_grad, + rank=self.rank, + world_size=self.world_size, + use_horovod=self.use_horovod, + mlp_loss=self.mlp_loss, + ) + + if self.local_loss: + logits_per_audio = ( + logit_scale_a * audio_features @ all_text_features.T + ) + logits_per_text = ( + logit_scale_a * text_features @ all_audio_features.T + ) + else: + logits_per_audio = ( + logit_scale_a * all_audio_features @ all_text_features.T + ) + logits_per_text = logits_per_audio.T + else: + logits_per_audio = logit_scale_a * audio_features @ text_features.T + logits_per_text = logit_scale_a * text_features @ audio_features.T + + # calculated ground-truth and cache if enabled + num_logits = logits_per_audio.shape[0] + if self.prev_num_logits != num_logits or device not in self.labels: + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if self.world_size > 1 and self.local_loss: + labels = labels + num_logits * self.rank + if self.cache_labels: + self.labels[device] = labels + self.prev_num_logits = num_logits + else: + labels = self.labels[device] + if not self.weighted_loss: + total_loss = ( + F.cross_entropy(logits_per_audio, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + else: + audio_weight = (all_audio_features @ all_audio_features.T).detach() + audio_weight = ( + torch.exp( + torch.sum(audio_weight, axis=1) + / (self.weight_loss_kappa * len(all_audio_features)) + ) + ).detach() + text_weight = (all_text_features @ all_text_features.T).detach() + text_weight = ( + torch.exp( + torch.sum(text_weight, axis=1) + / (self.weight_loss_kappa * len(all_text_features)) + ) + ).detach() + total_loss = ( + F.cross_entropy(logits_per_audio, labels, weight=text_weight) + + F.cross_entropy(logits_per_text, labels, weight=audio_weight) + ) / 2 + return total_loss + + +def lp_gather_features(pred, target, world_size=1, use_horovod=False): + if use_horovod: + assert hvd is not None, "Please install horovod" + with torch.no_grad(): + all_preds = hvd.allgather(pred) + all_targets = hvd.allgath(target) + else: + gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)] + gathered_targets = [torch.zeros_like(target) for _ in range(world_size)] + + dist.all_gather(gathered_preds, pred) + dist.all_gather(gathered_targets, target) + all_preds = torch.cat(gathered_preds, dim=0) + all_targets = torch.cat(gathered_targets, dim=0) + + return all_preds, all_targets + + +def get_map(pred, target): + pred = torch.sigmoid(pred).numpy() + target = target.numpy() + return np.mean(average_precision_score(target, pred, average=None)) + + +def get_acc(pred, target): + pred = torch.argmax(pred, 1).numpy() + target = torch.argmax(target, 1).numpy() + return accuracy_score(target, pred) + + +def get_mauc(pred, target): + pred = torch.sigmoid(pred).numpy() + target = target.numpy() + return np.mean(roc_auc_score(target, pred, average=None)) + + +class LPMetrics(object): + def __init__(self, metric_names=["map", "acc", "mauc"]): + self.metrics = [] + for name in metric_names: + self.metrics.append(self.get_metric(name)) + self.metric_names = metric_names + + def get_metric(self, name): + if name == "map": + return get_map + elif name == "acc": + return get_acc + elif name == "mauc": + return get_mauc + else: + raise ValueError(f"the metric should be at least one of [map, acc, mauc]") + + def evaluate_mertics(self, pred, target): + metric_dict = {} + for i in range(len(self.metric_names)): + metric_dict[self.metric_names[i]] = self.metrics[i](pred, target) + return metric_dict + + +def calc_celoss(pred, target): + target = torch.argmax(target, 1).long() + return nn.CrossEntropyLoss()(pred, target) + + +class LPLoss(nn.Module): + def __init__(self, loss_name): + super().__init__() + if loss_name == "bce": + self.loss_func = nn.BCEWithLogitsLoss() + elif loss_name == "ce": + self.loss_func = calc_celoss + elif loss_name == "mse": + self.loss_func = nn.MSELoss() + else: + raise ValueError(f"the loss func should be at least one of [bce, ce, mse]") + + def forward(self, pred, target): + loss = self.loss_func(pred, target) + return loss diff --git a/audioldm/clap/open_clip/model.py b/audioldm/clap/open_clip/model.py new file mode 100755 index 0000000000000000000000000000000000000000..b439244f8c293a0b4263b7ac1fd553e9d0adf184 --- /dev/null +++ b/audioldm/clap/open_clip/model.py @@ -0,0 +1,936 @@ +""" CLAP Model + +Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +Adapted to the Audio Task. +""" + +from collections import OrderedDict +from dataclasses import dataclass +from email.mime import audio +from typing import Tuple, Union, Callable, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from .timm_model import TimmModel +import logging +from .utils import freeze_batch_norm_2d + +from .pann_model import create_pann_model +from .htsat import create_htsat_model +from transformers import BertModel, RobertaModel, BartModel +from transformers.tokenization_utils_base import BatchEncoding + + +class MLPLayers(nn.Module): + def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1): + super(MLPLayers, self).__init__() + self.nonlin = nonlin + self.dropout = dropout + + sequence = [] + for u0, u1 in zip(units[:-1], units[1:]): + sequence.append(nn.Linear(u0, u1)) + sequence.append(self.nonlin) + sequence.append(nn.Dropout(self.dropout)) + sequence = sequence[:-2] + + self.sequential = nn.Sequential(*sequence) + + def forward(self, X): + X = self.sequential(X) + return X + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU(inplace=True) + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + OrderedDict( + [ + ("-1", nn.AvgPool2d(stride)), + ( + "0", + nn.Conv2d( + inplanes, + planes * self.expansion, + 1, + stride=1, + bias=False, + ), + ), + ("1", nn.BatchNorm2d(planes * self.expansion)), + ] + ) + ) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__( + self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None + ): + super().__init__() + self.positional_embedding = nn.Parameter( + torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5 + ) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute( + 2, 0, 1 + ) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, + key=x, + value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] + ), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False, + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert ( + unlocked_groups == 0 + ), "partial locking not currently supported for this model" + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + def stem(self, x): + for conv, bn in [ + (self.conv1, self.bn1), + (self.conv2, self.bn2), + (self.conv3, self.bn3), + ]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(d_model * 4, d_model)), + ] + ) + ) + self.ln_2 = LayerNorm(d_model) + + def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU + ): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock(width, heads, act_layer=act_layer) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + x = r(x, attn_mask=attn_mask) + return x + + +class VisualTransformer(nn.Module): + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + output_dim: int, + act_layer: Callable = nn.GELU, + ): + super().__init__() + self.image_size = image_size + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter( + scale * torch.randn((image_size // patch_size) ** 2 + 1, width) + ) + self.ln_pre = LayerNorm(width) + + self.text_branch = Transformer(width, layers, heads, act_layer=act_layer) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert ( + unlocked_groups == 0 + ), "partial locking not currently supported for this model" + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat( + [ + self.class_embedding.to(x.dtype) + + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device + ), + x, + ], + dim=1, + ) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_branch(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +@dataclass +class CLAPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + timm_model_name: str = ( + None # a valid model name overrides layers, width, patch_size + ) + timm_model_pretrained: bool = ( + False # use (imagenet) pretrained weights for named model + ) + timm_pool: str = ( + "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + ) + timm_proj: str = ( + "linear" # linear projection for timm model output ('linear', 'mlp', '') + ) + + +# Audio Config Class +@dataclass +class CLAPAudioCfp: + model_type: str = "PANN" + model_name: str = "Cnn14" + sample_rate: int = 48000 + # Param + audio_length: int = 1024 + window_size: int = 1024 + hop_size: int = 1024 + fmin: int = 50 + fmax: int = 14000 + class_num: int = 527 + mel_bins: int = 64 + clip_samples: int = 480000 + + +@dataclass +class CLAPTextCfg: + context_length: int + vocab_size: int + width: int + heads: int + layers: int + model_type: str + + +class CLAP(nn.Module): + def __init__( + self, + embed_dim: int, + audio_cfg: CLAPAudioCfp, + text_cfg: CLAPTextCfg, + quick_gelu: bool = False, + enable_fusion: bool = False, + fusion_type: str = "None", + joint_embed_shape: int = 512, + mlp_act: str = "relu", + ): + super().__init__() + if isinstance(audio_cfg, dict): + audio_cfg = CLAPAudioCfp(**audio_cfg) + if isinstance(text_cfg, dict): + text_cfg = CLAPTextCfg(**text_cfg) + + self.audio_cfg = audio_cfg + self.text_cfg = text_cfg + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + self.joint_embed_shape = joint_embed_shape + self.mlp_act = mlp_act + + self.context_length = text_cfg.context_length + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. + act_layer = QuickGELU if quick_gelu else nn.GELU + + if mlp_act == "relu": + mlp_act_layer = nn.ReLU() + elif mlp_act == "gelu": + mlp_act_layer = nn.GELU() + else: + raise NotImplementedError + + # audio branch + # audio branch parameters + if audio_cfg.model_type == "PANN": + self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type) + elif audio_cfg.model_type == "HTSAT": + self.audio_branch = create_htsat_model( + audio_cfg, enable_fusion, fusion_type + ) + else: + logging.error(f"Model config for {audio_cfg.model_type} not found") + raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.") + + # text branch + # text branch parameters + if text_cfg.model_type == "transformer": + self.text_branch = Transformer( + width=text_cfg.width, + layers=text_cfg.layers, + heads=text_cfg.heads, + act_layer=act_layer, + ) + self.vocab_size = text_cfg.vocab_size + self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, text_cfg.width) + ) + self.ln_final = LayerNorm(text_cfg.width) + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(text_cfg.width, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + elif text_cfg.model_type == "bert": + self.text_branch = BertModel.from_pretrained("bert-base-uncased") + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + elif text_cfg.model_type == "roberta": + self.text_branch = RobertaModel.from_pretrained("roberta-base") + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + elif text_cfg.model_type == "bart": + self.text_branch = BartModel.from_pretrained("facebook/bart-base") + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + else: + logging.error(f"Model config for {text_cfg.model_type} not found") + raise RuntimeError(f"Model config for {text_cfg.model_type} not found.") + self.text_branch_type = text_cfg.model_type + # text branch parameters + + # audio branch parameters + self.audio_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + + # below here is text branch parameters + + # ============================================================================================================ + self.audio_projection = nn.Sequential( + nn.Linear(embed_dim, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + + self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False) + + self.init_text_branch_parameters() + + def init_text_branch_parameters(self): + if self.text_branch_type == "transformer": + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + proj_std = (self.text_branch.width**-0.5) * ( + (2 * self.text_branch.layers) ** -0.5 + ) + attn_std = self.text_branch.width**-0.5 + fc_std = (2 * self.text_branch.width) ** -0.5 + for block in self.text_branch.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + if self.text_branch_type == "bert" or self.text_branch_type == "roberta": + width = self.text_branch.embeddings.word_embeddings.weight.shape[-1] + elif self.text_branch_type == "bart": + width = self.text_branch.shared.weight.shape[-1] + else: + width = self.text_branch.width + nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07)) + nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07)) + + # deprecated + # if hasattr(self.visual, 'init_parameters'): + # self.visual.init_parameters() + + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=width**-0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def encode_audio(self, audio, device): + return self.audio_branch( + audio, mixup_lambda=None, device=device + ) # mix lambda needs to add + + # def list_of_dict_of_tensor2dict_of_tensor(self, x, device): + # tmp = {} + # for k in x[0].keys(): + # tmp[k] = [] + # for i in range(len(x)): + # tmp[k].append(x[i][k][:77]) + # for k in x[0].keys(): + # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True) + # return tmp + + def encode_text(self, text, device): + if self.text_branch_type == "transformer": + text = text.to(device=device, non_blocking=True) + x = self.token_embedding(text) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_branch(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)]) + elif self.text_branch_type == "bert": + # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device) + # text = BatchEncoding(text) + x = self.text_branch( + input_ids=text["input_ids"].to(device=device, non_blocking=True), + attention_mask=text["attention_mask"].to( + device=device, non_blocking=True + ), + token_type_ids=text["token_type_ids"].to( + device=device, non_blocking=True + ), + )["pooler_output"] + x = self.text_projection(x) + elif self.text_branch_type == "roberta": + x = self.text_branch( + input_ids=text["input_ids"].to(device=device, non_blocking=True), + attention_mask=text["attention_mask"].to( + device=device, non_blocking=True + ), + )["pooler_output"] + x = self.text_projection(x) + elif self.text_branch_type == "bart": + x = torch.mean( + self.text_branch( + input_ids=text["input_ids"].to(device=device, non_blocking=True), + attention_mask=text["attention_mask"].to( + device=device, non_blocking=True + ), + )["encoder_last_hidden_state"], + axis=1, + ) + x = self.text_projection(x) + else: + logging.error(f"Model type {self.text_branch_type} not found") + raise RuntimeError(f"Model type {self.text_branch_type} not found.") + return x + + def forward(self, audio, text, device=None): + """Forward audio and text into the CLAP + + Parameters + ---------- + audio: torch.Tensor (batch_size, audio_length) + the time-domain audio input / the batch of mel_spec and longer list. + text: torch.Tensor () // need to add + the text token input + """ + if device is None: + if audio is not None: + device = audio.device + elif text is not None: + device = text.device + if audio is None and text is None: + # a hack to get the logit scale + return self.logit_scale_a.exp(), self.logit_scale_t.exp() + elif audio is None: + return self.encode_text(text, device=device) + elif text is None: + return self.audio_projection( + self.encode_audio(audio, device=device)["embedding"] + ) + audio_features = self.audio_projection( + self.encode_audio(audio, device=device)["embedding"] + ) + audio_features = F.normalize(audio_features, dim=-1) + + text_features = self.encode_text(text, device=device) + # print("text_features", text_features) + # print("text_features.shape", text_features.shape) + # print("text_features.type", type(text_features)) + text_features = F.normalize(text_features, dim=-1) + + audio_features_mlp = self.audio_transform(audio_features) + text_features_mlp = self.text_transform(text_features) + # Four outputs: audio features (basic & MLP), text features (basic & MLP) + return ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + self.logit_scale_a.exp(), + self.logit_scale_t.exp(), + ) + + def get_logit_scale(self): + return self.logit_scale_a.exp(), self.logit_scale_t.exp() + + def get_text_embedding(self, data): + """Get the text embedding from the model + + Parameters + ---------- + data: torch.Tensor + a tensor of text embedding + + Returns + ---------- + text_embed: torch.Tensor + a tensor of text_embeds (N, D) + + """ + device = next(self.parameters()).device + for k in data: + data[k] = data[k].to(device) + if len(data[k].size()) < 2: + data[k] = data[k].unsqueeze(0) + text_embeds = self.encode_text(data, device=device) + text_embeds = F.normalize(text_embeds, dim=-1) + + return text_embeds + + def get_audio_embedding(self, data): + """Get the audio embedding from the model + + Parameters + ---------- + data: a list of dict + the audio input dict list from 'get_audio_feature' method + + Returns + ---------- + audio_embed: torch.Tensor + a tensor of audio_embeds (N, D) + + """ + device = next(self.parameters()).device + input_dict = {} + keys = data[0].keys() + for k in keys: + input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to( + device + ) + + audio_embeds = self.audio_projection( + self.encode_audio(input_dict, device=device)["embedding"] + ) + audio_embeds = F.normalize(audio_embeds, dim=-1) + + return audio_embeds + + def audio_infer(self, audio, hopsize=None, device=None): + """Forward one audio and produce the audio embedding + + Parameters + ---------- + audio: (audio_length) + the time-domain audio input, notice that it must be only one input + hopsize: int + the overlap hopsize as the sliding window + + Returns + ---------- + output_dict: { + key: [n, (embedding_shape)] if "HTS-AT" + or + key: [(embedding_shape)] if "PANN" + } + the list of key values of the audio branch + + """ + + assert not self.training, "the inference mode must be run at eval stage" + output_dict = {} + # PANN + if self.audio_cfg.model_type == "PANN": + audio_input = audio.unsqueeze(dim=0) + output_dict[key] = self.encode_audio(audio_input, device=device)[ + key + ].squeeze(dim=0) + elif self.audio_cfg.model_type == "HTSAT": + # repeat + audio_len = len(audio) + k = self.audio_cfg.clip_samples // audio_len + if k > 1: + audio = audio.repeat(k) + audio_len = len(audio) + + if hopsize is None: + hopsize = min(hopsize, audio_len) + + if audio_len > self.audio_cfg.clip_samples: + audio_input = [ + audio[pos : pos + self.audio_cfg.clip_samples].clone() + for pos in range( + 0, audio_len - self.audio_cfg.clip_samples, hopsize + ) + ] + audio_input.append(audio[-self.audio_cfg.clip_samples :].clone()) + audio_input = torch.stack(audio_input) + output_dict[key] = self.encode_audio(audio_input, device=device)[key] + else: + audio_input = audio.unsqueeze(dim=0) + output_dict[key] = self.encode_audio(audio_input, device=device)[ + key + ].squeeze(dim=0) + + return output_dict + + +def convert_weights_to_fp16(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [ + *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], + "in_proj_bias", + "bias_k", + "bias_v", + ]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +# Ignore the state dict of the vision part +def build_model_from_openai_state_dict( + state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None" +): + + embed_dim = model_cfg["embed_dim"] + audio_cfg = model_cfg["audio_cfg"] + text_cfg = model_cfg["text_cfg"] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len( + set( + k.split(".")[2] + for k in state_dict + if k.startswith(f"transformer.resblocks") + ) + ) + + audio_cfg = CLAPAudioCfp(**audio_cfg) + text_cfg = CLAPTextCfg(**text_cfg) + + model = CLAP( + embed_dim, + audio_cfg=audio_cfg, + text_cfg=text_cfg, + quick_gelu=True, # OpenAI models were trained with QuickGELU + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + state_dict["logit_scale_a"] = state_dict["logit_scale"] + state_dict["logit_scale_t"] = state_dict["logit_scale"] + pop_keys = list(state_dict.keys())[::] + # pop the visual branch saved weights + for key in pop_keys: + if key.startswith("visual."): + state_dict.pop(key, None) + + for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + + # not use fp16 + # convert_weights_to_fp16(model) + model.load_state_dict(state_dict, strict=False) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device("cpu")): + model.eval() + audio_length = model.audio_cfg.audio_length + example_audio = torch.ones((batch_size, audio_length), device=device) + example_text = torch.zeros( + (batch_size, model.context_length), dtype=torch.int, device=device + ) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_audio, example_text), + encode_text=(example_text,), + encode_image=(example_audio,), + ), + ) + model.audio_cfg.audio_length = audio_length # Question: what does this do? + return model diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-base.json b/audioldm/clap/open_clip/model_configs/HTSAT-base.json new file mode 100755 index 0000000000000000000000000000000000000000..6cef625a89daf4431f1c9f72e10bc9640eef2ba8 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/HTSAT-base.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 1024, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "base" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-large.json b/audioldm/clap/open_clip/model_configs/HTSAT-large.json new file mode 100755 index 0000000000000000000000000000000000000000..699cdb1b16855582606551e4196b24aba2ffd871 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/HTSAT-large.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "large" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json b/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json new file mode 100755 index 0000000000000000000000000000000000000000..73e42990fe8361a0df502e7f93d29f19f58c9ecb --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 768, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1536, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "tiny" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json b/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json new file mode 100755 index 0000000000000000000000000000000000000000..a6e7821163d9afa81c27345a1e472475b92af169 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 768, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "tiny" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-10.json b/audioldm/clap/open_clip/model_configs/PANN-10.json new file mode 100755 index 0000000000000000000000000000000000000000..954ddf62921aed7dde9c37ffffec98a2e96a4ee7 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-10.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 1024, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn10" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json new file mode 100755 index 0000000000000000000000000000000000000000..b7989bc0cd95d0d39049b7524eba508b3e386439 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 18000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json new file mode 100755 index 0000000000000000000000000000000000000000..56bdb56bedc304ffa52d8bf5988cea2c1d82d14e --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 960000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 360, + "fmin": 50, + "fmax": 8000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json b/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json new file mode 100755 index 0000000000000000000000000000000000000000..5756e3bebc97cc985f512cb081930fee4e49bec1 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 4 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json b/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json new file mode 100755 index 0000000000000000000000000000000000000000..5a9e7e208b661619d5e26625e849da1adda8a475 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1536, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14.json b/audioldm/clap/open_clip/model_configs/PANN-14.json new file mode 100755 index 0000000000000000000000000000000000000000..39a5134cde1d8c50f4758377c952ef22f07bab41 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-14.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-6.json b/audioldm/clap/open_clip/model_configs/PANN-6.json new file mode 100755 index 0000000000000000000000000000000000000000..21ebc344326de260c386ba77e0ad63cf9b04febf --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-6.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 512, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn6" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json b/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json new file mode 100755 index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN101.json b/audioldm/clap/open_clip/model_configs/RN101.json new file mode 100755 index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/RN101.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json b/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json new file mode 100755 index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/audioldm/clap/open_clip/model_configs/RN50.json b/audioldm/clap/open_clip/model_configs/RN50.json new file mode 100755 index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/RN50.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN50x16.json b/audioldm/clap/open_clip/model_configs/RN50x16.json new file mode 100755 index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/RN50x16.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 384, + "layers": [ + 6, + 8, + 18, + 8 + ], + "width": 96, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN50x4.json b/audioldm/clap/open_clip/model_configs/RN50x4.json new file mode 100755 index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/RN50x4.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 288, + "layers": [ + 4, + 6, + 10, + 6 + ], + "width": 80, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/ViT-B-16.json b/audioldm/clap/open_clip/model_configs/ViT-B-16.json new file mode 100755 index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/ViT-B-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json b/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json new file mode 100755 index 0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/ViT-B-32.json b/audioldm/clap/open_clip/model_configs/ViT-B-32.json new file mode 100755 index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/ViT-L-14.json b/audioldm/clap/open_clip/model_configs/ViT-L-14.json new file mode 100755 index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/ViT-L-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/openai.py b/audioldm/clap/open_clip/openai.py new file mode 100755 index 0000000000000000000000000000000000000000..fcb624f54a8b9d2c4b11e3adb50c53c3261716d4 --- /dev/null +++ b/audioldm/clap/open_clip/openai.py @@ -0,0 +1,159 @@ +""" OpenAI pretrained model functions + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" + +import os +import warnings +from typing import Union, List + +import torch + +from .model import build_model_from_openai_state_dict +from .pretrained import ( + get_pretrained_url, + list_pretrained_tag_models, + download_pretrained, +) + +__all__ = ["list_openai_models", "load_openai_model"] + +CACHE_DIR = os.getenv("AUDIOLDM_CACHE_DIR", "~/.cache") + + + +def list_openai_models() -> List[str]: + """Returns the names of available CLIP models""" + return list_pretrained_tag_models("openai") + + +def load_openai_model( + name: str, + model_cfg, + device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", + jit=True, + cache_dir=os.path.expanduser(f"{CACHE_DIR}/clip"), + enable_fusion: bool = False, + fusion_type: str = "None", +): + """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + device : Union[str, torch.device] + The device to put the loaded model + jit : bool + Whether to load the optimized JIT model (default) or more hackable non-JIT model. + + Returns + ------- + model : torch.nn.Module + The CLAP model + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if get_pretrained_url(name, "openai"): + model_path = download_pretrained( + get_pretrained_url(name, "openai"), root=cache_dir + ) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError( + f"Model {name} not found; available models = {list_openai_models()}" + ) + + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn( + f"File {model_path} is not a JIT archive. Loading as a state dict instead" + ) + jit = False + state_dict = torch.load(model_path, map_location="cpu") + + if not jit: + try: + model = build_model_from_openai_state_dict( + state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type + ).to(device) + except KeyError: + sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} + model = build_model_from_openai_state_dict( + sd, model_cfg, enable_fusion, fusion_type + ).to(device) + + if str(device) == "cpu": + model.float() + return model + + # patch the device names + device_holder = torch.jit.trace( + lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] + ) + device_node = [ + n + for n in device_holder.graph.findAllNodes("prim::Constant") + if "Device" in repr(n) + ][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith( + "cuda" + ): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_audio) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace( + lambda: torch.ones([]).float(), example_inputs=[] + ) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [ + 1, + 2, + ]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_audio) + patch_float(model.encode_text) + model.float() + + model.audio_branch.audio_length = model.audio_cfg.audio_length + return model diff --git a/audioldm/clap/open_clip/pann_model.py b/audioldm/clap/open_clip/pann_model.py new file mode 100755 index 0000000000000000000000000000000000000000..0d9a8eb0bf897ad6ec04923361b01e5de433b2ef --- /dev/null +++ b/audioldm/clap/open_clip/pann_model.py @@ -0,0 +1,704 @@ +# PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition +# Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn +# Some layers are re-designed for CLAP +import os + +os.environ["NUMBA_CACHE_DIR"] = "/tmp/" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchlibrosa.stft import Spectrogram, LogmelFilterBank +from torchlibrosa.augmentation import SpecAugmentation + +from .utils import do_mixup, interpolate, pad_framewise_output +from .feature_fusion import iAFF, AFF, DAF + + +def init_layer(layer): + """Initialize a Linear or Convolutional layer.""" + nn.init.xavier_uniform_(layer.weight) + + if hasattr(layer, "bias"): + if layer.bias is not None: + layer.bias.data.fill_(0.0) + + +def init_bn(bn): + """Initialize a Batchnorm layer.""" + bn.bias.data.fill_(0.0) + bn.weight.data.fill_(1.0) + + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels): + + super(ConvBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ) + + self.conv2 = nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ) + + self.bn1 = nn.BatchNorm2d(out_channels) + self.bn2 = nn.BatchNorm2d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_layer(self.conv2) + init_bn(self.bn1) + init_bn(self.bn2) + + def forward(self, input, pool_size=(2, 2), pool_type="avg"): + + x = input + x = F.relu_(self.bn1(self.conv1(x))) + x = F.relu_(self.bn2(self.conv2(x))) + if pool_type == "max": + x = F.max_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg": + x = F.avg_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg+max": + x1 = F.avg_pool2d(x, kernel_size=pool_size) + x2 = F.max_pool2d(x, kernel_size=pool_size) + x = x1 + x2 + else: + raise Exception("Incorrect argument!") + + return x + + +class ConvBlock5x5(nn.Module): + def __init__(self, in_channels, out_channels): + + super(ConvBlock5x5, self).__init__() + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(5, 5), + stride=(1, 1), + padding=(2, 2), + bias=False, + ) + + self.bn1 = nn.BatchNorm2d(out_channels) + + self.init_weight() + + def init_weight(self): + init_layer(self.conv1) + init_bn(self.bn1) + + def forward(self, input, pool_size=(2, 2), pool_type="avg"): + + x = input + x = F.relu_(self.bn1(self.conv1(x))) + if pool_type == "max": + x = F.max_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg": + x = F.avg_pool2d(x, kernel_size=pool_size) + elif pool_type == "avg+max": + x1 = F.avg_pool2d(x, kernel_size=pool_size) + x2 = F.max_pool2d(x, kernel_size=pool_size) + x = x1 + x2 + else: + raise Exception("Incorrect argument!") + + return x + + +class AttBlock(nn.Module): + def __init__(self, n_in, n_out, activation="linear", temperature=1.0): + super(AttBlock, self).__init__() + + self.activation = activation + self.temperature = temperature + self.att = nn.Conv1d( + in_channels=n_in, + out_channels=n_out, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) + self.cla = nn.Conv1d( + in_channels=n_in, + out_channels=n_out, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) + + self.bn_att = nn.BatchNorm1d(n_out) + self.init_weights() + + def init_weights(self): + init_layer(self.att) + init_layer(self.cla) + init_bn(self.bn_att) + + def forward(self, x): + # x: (n_samples, n_in, n_time) + norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) + cla = self.nonlinear_transform(self.cla(x)) + x = torch.sum(norm_att * cla, dim=2) + return x, norm_att, cla + + def nonlinear_transform(self, x): + if self.activation == "linear": + return x + elif self.activation == "sigmoid": + return torch.sigmoid(x) + + +class Cnn14(nn.Module): + def __init__( + self, + sample_rate, + window_size, + hop_size, + mel_bins, + fmin, + fmax, + classes_num, + enable_fusion=False, + fusion_type="None", + ): + + super(Cnn14, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + if (self.enable_fusion) and (self.fusion_type == "channel_map"): + self.conv_block1 = ConvBlock(in_channels=4, out_channels=64) + else: + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) + + self.fc1 = nn.Linear(2048, 2048, bias=True) + self.fc_audioset = nn.Linear(2048, classes_num, bias=True) + + if (self.enable_fusion) and ( + self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"] + ): + self.mel_conv1d = nn.Sequential( + nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2), + nn.BatchNorm1d(64), # No Relu + ) + if self.fusion_type == "daf_1d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_1d": + self.fusion_model = AFF(channels=64, type="1D") + elif self.fusion_type == "iaff_1d": + self.fusion_model = iAFF(channels=64, type="1D") + + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + self.mel_conv2d = nn.Sequential( + nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + ) + + if self.fusion_type == "daf_2d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_2d": + self.fusion_model = AFF(channels=64, type="2D") + elif self.fusion_type == "iaff_2d": + self.fusion_model = iAFF(channels=64, type="2D") + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None, device=None): + """ + Input: (batch_size, data_length)""" + + if self.enable_fusion and input["longer"].sum() == 0: + # if no audio is longer than 10s, then randomly select one audio to be longer + input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True + + if not self.enable_fusion: + x = self.spectrogram_extractor( + input["waveform"].to(device=device, non_blocking=True) + ) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + else: + longer_list = input["longer"].to(device=device, non_blocking=True) + x = input["mel_fusion"].to(device=device, non_blocking=True) + longer_list_idx = torch.where(longer_list)[0] + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]: + new_x = x[:, 0:1, :, :].clone().contiguous() + # local processing + if len(longer_list_idx) > 0: + fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous() + FB, FC, FT, FF = fusion_x_local.size() + fusion_x_local = fusion_x_local.view(FB * FC, FT, FF) + fusion_x_local = torch.permute( + fusion_x_local, (0, 2, 1) + ).contiguous() + fusion_x_local = self.mel_conv1d(fusion_x_local) + fusion_x_local = fusion_x_local.view( + FB, FC, FF, fusion_x_local.size(-1) + ) + fusion_x_local = ( + torch.permute(fusion_x_local, (0, 2, 1, 3)) + .contiguous() + .flatten(2) + ) + if fusion_x_local.size(-1) < FT: + fusion_x_local = torch.cat( + [ + fusion_x_local, + torch.zeros( + (FB, FF, FT - fusion_x_local.size(-1)), + device=device, + ), + ], + dim=-1, + ) + else: + fusion_x_local = fusion_x_local[:, :, :FT] + # 1D fusion + new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous() + new_x[longer_list_idx] = self.fusion_model( + new_x[longer_list_idx], fusion_x_local + ) + x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :] + else: + x = new_x + elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]: + x = x # no change + + if self.training: + x = self.spec_augmenter(x) + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + global_x = x[:, 0:1, :, :] + + # global processing + B, C, H, W = global_x.shape + global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg") + if len(longer_list_idx) > 0: + local_x = x[longer_list_idx, 1:, :, :].contiguous() + TH = global_x.size(-2) + # local processing + B, C, H, W = local_x.shape + local_x = local_x.view(B * C, 1, H, W) + local_x = self.mel_conv2d(local_x) + local_x = local_x.view( + B, C, local_x.size(1), local_x.size(2), local_x.size(3) + ) + local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3) + TB, TC, _, TW = local_x.size() + if local_x.size(-2) < TH: + local_x = torch.cat( + [ + local_x, + torch.zeros( + (TB, TC, TH - local_x.size(-2), TW), + device=global_x.device, + ), + ], + dim=-2, + ) + else: + local_x = local_x[:, :, :TH, :] + + global_x[longer_list_idx] = self.fusion_model( + global_x[longer_list_idx], local_x + ) + x = global_x + else: + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x = latent_x1 + latent_x2 + latent_x = latent_x.transpose(1, 2) + latent_x = F.relu_(self.fc1(latent_x)) + latent_output = interpolate(latent_x, 32) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = { + "clipwise_output": clipwise_output, + "embedding": embedding, + "fine_grained_embedding": latent_output, + } + return output_dict + + +class Cnn6(nn.Module): + def __init__( + self, + sample_rate, + window_size, + hop_size, + mel_bins, + fmin, + fmax, + classes_num, + enable_fusion=False, + fusion_type="None", + ): + + super(Cnn6, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512) + + self.fc1 = nn.Linear(512, 512, bias=True) + self.fc_audioset = nn.Linear(512, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None, device=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x = latent_x1 + latent_x2 + latent_x = latent_x.transpose(1, 2) + latent_x = F.relu_(self.fc1(latent_x)) + latent_output = interpolate(latent_x, 16) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = { + "clipwise_output": clipwise_output, + "embedding": embedding, + "fine_grained_embedding": latent_output, + } + + return output_dict + + +class Cnn10(nn.Module): + def __init__( + self, + sample_rate, + window_size, + hop_size, + mel_bins, + fmin, + fmax, + classes_num, + enable_fusion=False, + fusion_type="None", + ): + + super(Cnn10, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + + self.fc1 = nn.Linear(1024, 1024, bias=True) + self.fc_audioset = nn.Linear(1024, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None, device=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x = latent_x1 + latent_x2 + latent_x = latent_x.transpose(1, 2) + latent_x = F.relu_(self.fc1(latent_x)) + latent_output = interpolate(latent_x, 32) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = { + "clipwise_output": clipwise_output, + "embedding": embedding, + "fine_grained_embedding": latent_output, + } + + return output_dict + + +def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"): + try: + ModelProto = eval(audio_cfg.model_name) + model = ModelProto( + sample_rate=audio_cfg.sample_rate, + window_size=audio_cfg.window_size, + hop_size=audio_cfg.hop_size, + mel_bins=audio_cfg.mel_bins, + fmin=audio_cfg.fmin, + fmax=audio_cfg.fmax, + classes_num=audio_cfg.class_num, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + return model + except: + raise RuntimeError( + f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough." + ) diff --git a/audioldm/clap/open_clip/pretrained.py b/audioldm/clap/open_clip/pretrained.py new file mode 100755 index 0000000000000000000000000000000000000000..8ed2ae1732a28c4e98d1f3412157ef27054e41dc --- /dev/null +++ b/audioldm/clap/open_clip/pretrained.py @@ -0,0 +1,169 @@ +import hashlib +import os +import urllib +import warnings + +from tqdm import tqdm + +CACHE_DIR = os.getenv("AUDIOLDM_CACHE_DIR", "~/.cache") + +_RN50 = dict( + openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", + cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", +) + +_RN50_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", + cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", +) + +_RN101 = dict( + openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", +) + +_RN101_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", +) + +_RN50x4 = dict( + openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", +) + +_RN50x16 = dict( + openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", +) + +_RN50x64 = dict( + openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", +) + +_VITB32 = dict( + openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", + laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", + laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", +) + +_VITB32_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", + laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", + laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", +) + +_VITB16 = dict( + openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", +) + +_VITL14 = dict( + openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", +) + +_PRETRAINED = { + "RN50": _RN50, + "RN50-quickgelu": _RN50_quickgelu, + "RN101": _RN101, + "RN101-quickgelu": _RN101_quickgelu, + "RN50x4": _RN50x4, + "RN50x16": _RN50x16, + "ViT-B-32": _VITB32, + "ViT-B-32-quickgelu": _VITB32_quickgelu, + "ViT-B-16": _VITB16, + "ViT-L-14": _VITL14, +} + + +def list_pretrained(as_str: bool = False): + """returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [ + ":".join([k, t]) if as_str else (k, t) + for k in _PRETRAINED.keys() + for t in _PRETRAINED[k].keys() + ] + + +def list_pretrained_tag_models(tag: str): + """return all models having the specified pretrain tag""" + models = [] + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_model_tags(model: str): + """return all pretrain tags for the specified model architecture""" + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def get_pretrained_url(model: str, tag: str): + if model not in _PRETRAINED: + return "" + model_pretrained = _PRETRAINED[model] + if tag not in model_pretrained: + return "" + return model_pretrained[tag] + + +def download_pretrained(url: str, root: str = os.path.expanduser(f"{CACHE_DIR}/clip")): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + if "openaipublic" in url: + expected_sha256 = url.split("/")[-2] + else: + expected_sha256 = "" + + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if ( + hashlib.sha256(open(download_target, "rb").read()).hexdigest() + == expected_sha256 + ): + return download_target + else: + warnings.warn( + f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" + ) + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm( + total=int(source.info().get("Content-Length")), + ncols=80, + unit="iB", + unit_scale=True, + ) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if ( + expected_sha256 + and hashlib.sha256(open(download_target, "rb").read()).hexdigest() + != expected_sha256 + ): + raise RuntimeError( + f"Model has been downloaded but the SHA256 checksum does not not match" + ) + + return download_target diff --git a/audioldm/clap/open_clip/timm_model.py b/audioldm/clap/open_clip/timm_model.py new file mode 100755 index 0000000000000000000000000000000000000000..c9d1ab4666b5bab5038d44b90c9ddca5087de460 --- /dev/null +++ b/audioldm/clap/open_clip/timm_model.py @@ -0,0 +1,112 @@ +""" timm model adapter + +Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. +""" +from collections import OrderedDict + +import torch.nn as nn + +try: + import timm + from timm.models.layers import Mlp, to_2tuple + from timm.models.layers.attention_pool2d import RotAttentionPool2d + from timm.models.layers.attention_pool2d import ( + AttentionPool2d as AbsAttentionPool2d, + ) +except ImportError as e: + timm = None + +from .utils import freeze_batch_norm_2d + + +class TimmModel(nn.Module): + """timm model adapter + # FIXME this adapter is a work in progress, may change in ways that break weight compat + """ + + def __init__( + self, + model_name, + embed_dim, + image_size=224, + pool="avg", + proj="linear", + drop=0.0, + pretrained=False, + ): + super().__init__() + if timm is None: + raise RuntimeError("Please `pip install timm` to use timm models.") + + self.image_size = to_2tuple(image_size) + self.trunk = timm.create_model(model_name, pretrained=pretrained) + feat_size = self.trunk.default_cfg.get("pool_size", None) + feature_ndim = 1 if not feat_size else 2 + if pool in ("abs_attn", "rot_attn"): + assert feature_ndim == 2 + # if attn pooling used, remove both classifier and default pool + self.trunk.reset_classifier(0, global_pool="") + else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + if pool == "abs_attn": + head_layers["pool"] = AbsAttentionPool2d( + prev_chs, feat_size=feat_size, out_features=embed_dim + ) + prev_chs = embed_dim + elif pool == "rot_attn": + head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + else: + assert proj, "projection layer needed if non-attention pooling is used." + + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == "linear": + head_layers["drop"] = nn.Dropout(drop) + head_layers["proj"] = nn.Linear(prev_chs, embed_dim) + elif proj == "mlp": + head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError( + "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`" + ) + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + self.trunk.get_parameter(param).requires_grad = False + if freeze_bn_stats: + gmodules = group_modules(self.trunk, matcher, reverse=True) + gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} + freeze_batch_norm_2d(self.trunk, gmodules) + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x diff --git a/audioldm/clap/open_clip/tokenizer.py b/audioldm/clap/open_clip/tokenizer.py new file mode 100755 index 0000000000000000000000000000000000000000..ee4d28450ec5dd12a79daf38cf3088e9e73c2cd5 --- /dev/null +++ b/audioldm/clap/open_clip/tokenizer.py @@ -0,0 +1,197 @@ +""" CLIP tokenizer + +Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import gzip +import html +import os +from functools import lru_cache +from typing import Union, List + +import ftfy +import regex as re +import torch + + +@lru_cache() +def default_bpe(): + return os.path.join( + os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" + ) + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") + merges = merges[1 : 49152 - 256 - 2 + 1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v + "" for v in vocab] + for merge in merges: + vocab.append("".join(merge)) + if not special_tokens: + special_tokens = ["", ""] + else: + special_tokens = ["", ""] + special_tokens + vocab.extend(special_tokens) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {t: t for t in special_tokens} + special = "|".join(special_tokens) + self.pat = re.compile( + special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + + self.vocab_size = len(self.encoder) + self.all_special_ids = [self.encoder[t] for t in special_tokens] + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text + + +_tokenizer = SimpleTokenizer() + + +def tokenize( + texts: Union[str, List[str]], context_length: int = 77 +) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + eot_token = _tokenizer.encoder[""] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + result[i, : len(tokens)] = torch.tensor(tokens) + + return result diff --git a/audioldm/clap/open_clip/transform.py b/audioldm/clap/open_clip/transform.py new file mode 100755 index 0000000000000000000000000000000000000000..77aaa722c4a5544ac50de6df35d3e922f63b111d --- /dev/null +++ b/audioldm/clap/open_clip/transform.py @@ -0,0 +1,45 @@ +from torchvision.transforms import ( + Normalize, + Compose, + RandomResizedCrop, + InterpolationMode, + ToTensor, + Resize, + CenterCrop, +) + + +def _convert_to_rgb(image): + return image.convert("RGB") + + +def image_transform( + image_size: int, + is_train: bool, + mean=(0.48145466, 0.4578275, 0.40821073), + std=(0.26862954, 0.26130258, 0.27577711), +): + normalize = Normalize(mean=mean, std=std) + if is_train: + return Compose( + [ + RandomResizedCrop( + image_size, + scale=(0.9, 1.0), + interpolation=InterpolationMode.BICUBIC, + ), + _convert_to_rgb, + ToTensor(), + normalize, + ] + ) + else: + return Compose( + [ + Resize(image_size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(image_size), + _convert_to_rgb, + ToTensor(), + normalize, + ] + ) diff --git a/audioldm/clap/open_clip/utils.py b/audioldm/clap/open_clip/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..34ecbced4cb7e6b6f92154a666e2c7efc7c922c6 --- /dev/null +++ b/audioldm/clap/open_clip/utils.py @@ -0,0 +1,362 @@ +import numpy as np +import torch +from torch import nn as nn +from torchvision.ops.misc import FrozenBatchNorm2d +import logging + +# import h5py +from tqdm import tqdm +import random +import json +import os +import pathlib + +# TODO: (yusong) this not a good place to store those information and does not scale. Need to be fixed later. +dataset_split = { + "audiocaps": ["train", "valid", "test"], + "audioset": ["balanced_train", "unbalanced_train", "eval"], + "BBCSoundEffects": ["train", "test"], + "Clotho": ["train", "test", "valid"], + "free_to_use_sounds": ["train", "test"], + "paramount_motion": ["train", "test"], + "sonniss_game_effects": ["train", "test"], + "wesoundeffects": ["train", "test"], + "MACS": ["train", "test"], + "freesound": ["train", "test"], + "FSD50K": ["train", "test", "valid"], + "fsd50k_class_label": ["train", "test", "valid"], + "esc50": ["train", "test"], + "audiostock": ["train", "test"], + "freesound_no_overlap_noesc50": ["train", "test"], + "epidemic_sound_effects": ["train", "test"], + "VGGSound": ["train", "test"], + "urbansound8k_class_label": ["train", "test"], + "audioset_t5": ["balanced_train", "unbalanced_train", "eval"], + "epidemic_sound_effects_t5": ["train", "test"], + "WavText5K": ["train", "test"], + "esc50_no_overlap": ["train", "test"], + "usd8k_no_overlap": ["train", "test"], + "fsd50k_200_class_label": ["train", "test", "valid"], +} + + +def freeze_batch_norm_2d(module, module_match={}, name=""): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance( + module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm) + ): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = ".".join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +def exist(dataset_name, dataset_type): + """ + Check if dataset exists + """ + if dataset_type in dataset_split[dataset_name]: + return True + else: + return False + + +def get_tar_path_from_dataset_name( + dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None +): + """ + Get tar path from dataset name and type + """ + output = [] + for n in dataset_names: + if full_dataset is not None and n in full_dataset: + current_dataset_types = dataset_split[n] + else: + current_dataset_types = dataset_types + for s in current_dataset_types: + tmp = [] + if islocal: + sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json" + if not os.path.exists(sizefilepath_): + sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" + else: + sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" + if not os.path.exists(sizefilepath_): + continue + sizes = json.load(open(sizefilepath_, "r")) + for k in sizes.keys(): + if islocal: + tmp.append(f"{dataset_path}/{n}/{s}/{k}") + else: + tmp.append( + f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -" + ) + if proportion != 1: + tmp = random.sample(tmp, int(proportion * len(tmp))) + output.append(tmp) + return sum(output, []) + + +def get_tar_path_from_txts(txt_path, islocal, proportion=1): + """ + Get tar path from txt path + """ + if isinstance(txt_path, (list, tuple)): + return sum( + [ + get_tar_path_from_txts( + txt_path[i], islocal=islocal, proportion=proportion + ) + for i in range(len(txt_path)) + ], + [], + ) + if isinstance(txt_path, str): + with open(txt_path) as f: + lines = f.readlines() + if islocal: + lines = [ + lines[i] + .split("\n")[0] + .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/") + for i in range(len(lines)) + ] + else: + lines = [ + lines[i].split("\n")[0].replace(".tar", ".tar -") + for i in range(len(lines)) + ] + if proportion != 1: + print("Sampling tars with proportion of {}".format(proportion)) + lines = random.sample(lines, int(proportion * len(lines))) + return lines + + +def get_mix_lambda(mixup_alpha, batch_size): + mixup_lambdas = [ + np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size) + ] + return np.array(mixup_lambdas).astype(np.float32) + + +def do_mixup(x, mixup_lambda): + """ + Args: + x: (batch_size , ...) + mixup_lambda: (batch_size,) + Returns: + out: (batch_size, ...) + """ + out = ( + x.transpose(0, -1) * mixup_lambda + + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda) + ).transpose(0, -1) + return out + + +def interpolate(x, ratio): + """Interpolate data in time domain. This is used to compensate the + resolution reduction in downsampling of a CNN. + + Args: + x: (batch_size, time_steps, classes_num) + ratio: int, ratio to interpolate + Returns: + upsampled: (batch_size, time_steps * ratio, classes_num) + """ + (batch_size, time_steps, classes_num) = x.shape + upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) + upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) + return upsampled + + +def pad_framewise_output(framewise_output, frames_num): + """Pad framewise_output to the same length as input frames. The pad value + is the same as the value of the last frame. + Args: + framewise_output: (batch_size, frames_num, classes_num) + frames_num: int, number of frames to pad + Outputs: + output: (batch_size, frames_num, classes_num) + """ + pad = framewise_output[:, -1:, :].repeat( + 1, frames_num - framewise_output.shape[1], 1 + ) + """tensor for padding""" + + output = torch.cat((framewise_output, pad), dim=1) + """(batch_size, frames_num, classes_num)""" + + +# def process_ipc(index_path, classes_num, filename): +# # load data +# logging.info("Load Data...............") +# ipc = [[] for _ in range(classes_num)] +# with h5py.File(index_path, "r") as f: +# for i in tqdm(range(len(f["target"]))): +# t_class = np.where(f["target"][i])[0] +# for t in t_class: +# ipc[t].append(i) +# print(ipc) +# np.save(filename, ipc) +# logging.info("Load Data Succeed...............") + + +def save_to_dict(s, o_={}): + sp = s.split(": ") + o_.update({sp[0]: float(sp[1])}) + return o_ + + +def get_data_from_log(txt_path): + """ + Output dictionary from out.txt log file + """ + with open(txt_path) as f: + lines = f.readlines() + val_data = {} + train_data = {} + train_losses = [] + train_losses_epoch = [] + for i in range(len(lines)): + if "| INFO |" in lines[i]: + if "Eval Epoch" in lines[i]: + if "val_loss" in lines[i]: + # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", "")) + line = lines[i].split("Eval Epoch: ")[-1] + num_epoch = int(line.split(" ")[0].split(" ")[0]) + d = { + line.split(" ")[0] + .split(" ")[1] + .replace(":", ""): float(line.split(" ")[0].split(" ")[-1]) + } + for i in range(1, len(line.split(" "))): + d = save_to_dict(line.split(" ")[i], d) + val_data[num_epoch] = d + elif "Train Epoch" in lines[i]: + num_epoch = int(lines[i].split("Train Epoch: ")[1][0]) + loss = float(lines[i].split("Loss: ")[-1].split(" (")[0]) + train_losses.append(loss) + train_losses_epoch.append(num_epoch) + for i in range(len(train_losses)): + train_data[i] = { + "num_epoch": train_losses_epoch[i], + "train_loss": train_losses[i], + } + return train_data, val_data + + +def save_p(obj, filename): + import pickle + + try: + from deepdiff import DeepDiff + except: + os.system("pip install deepdiff") + from deepdiff import DeepDiff + with open(filename, "wb") as file: + pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol + with open(filename, "rb") as file: + z = pickle.load(file) + assert ( + DeepDiff(obj, z, ignore_string_case=True) == {} + ), "there is something wrong with the saving process" + return + + +def load_p(filename): + import pickle + + with open(filename, "rb") as file: + z = pickle.load(file) + return z + + +def save_json(data, name="data.json"): + import json + + with open(name, "w") as fp: + json.dump(data, fp) + return + + +def load_json(name): + import json + + with open(name, "r") as fp: + data = json.load(fp) + return data + + +from multiprocessing import Process, Manager +from multiprocessing import Process, Value, Array +from ctypes import c_wchar + + +def load_class_label(path): + # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing + # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array + out = None + if path is not None: + if pathlib.Path(path).suffix in [".pkl", ".pickle"]: + out = load_p(path) + elif pathlib.Path(path).suffix in [".json", ".txt"]: + out = load_json(path) + elif pathlib.Path(path).suffix in [".npy", ".npz"]: + out = np.load(path) + elif pathlib.Path(path).suffix in [".csv"]: + import pandas as pd + + out = pd.read_csv(path) + return out + # if out is None: + # return None + # else: + # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False) + # val = Array('i', out.values(), lock=False) + # return (key, val) + + +from torch import optim + + +def get_optimizer(params, lr, betas, eps, momentum, optimizer_name): + if optimizer_name.lower() == "adamw": + optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps) + elif optimizer_name.lower() == "sgd": + optimizer = optim.SGD(params, lr=lr, momentum=momentum) + elif optimizer_name.lower() == "adam": + optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps) + else: + raise ValueError("optimizer name is not correct") + return optimizer diff --git a/audioldm/clap/open_clip/version.py b/audioldm/clap/open_clip/version.py new file mode 100755 index 0000000000000000000000000000000000000000..3ced3581bb601ae91b1e1da4b8f4f520855a065e --- /dev/null +++ b/audioldm/clap/open_clip/version.py @@ -0,0 +1 @@ +__version__ = "0.2.1" diff --git a/audioldm/clap/training/__init__.py b/audioldm/clap/training/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm/clap/training/audioset_textmap.npy b/audioldm/clap/training/audioset_textmap.npy new file mode 100755 index 0000000000000000000000000000000000000000..3da4c92d3819aaec11e5f576464a9973a6df811b --- /dev/null +++ b/audioldm/clap/training/audioset_textmap.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bada103070d92f9eadd33e1b4f45ec8583f59080ef218c966b43294bd4c86d5b +size 84448 diff --git a/audioldm/clap/training/data.py b/audioldm/clap/training/data.py new file mode 100755 index 0000000000000000000000000000000000000000..a005fee2f51e577446839b8cffd117d9ae93abc9 --- /dev/null +++ b/audioldm/clap/training/data.py @@ -0,0 +1,981 @@ +import ast +import json +import logging +import math +import os +import random + +# import h5py +from dataclasses import dataclass +from audioldm.clap.training.params import parse_args + +# import braceexpand +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.datasets as datasets +import torchvision.transforms + +# import webdataset as wds +from PIL import Image +from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler +from torch.utils.data.distributed import DistributedSampler +from functools import partial +import soundfile as sf +import io +from pathlib import Path + +# import wget + +from audioldm.clap.open_clip.utils import ( + get_tar_path_from_dataset_name, + dataset_split, +) +from audioldm.clap.open_clip.utils import load_p, load_class_label +import copy + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +try: + import torchaudio +except ImportError: + torchaudio = None + +from audioldm.clap.open_clip import tokenize + + +def tokenizer(text): + return tokenize(text).squeeze(0) + + +from transformers import RobertaTokenizer + +tokenize = RobertaTokenizer.from_pretrained("roberta-base") + + +def tokenizer(text): + result = tokenize( + text, + padding="max_length", + truncation=True, + max_length=77, + return_tensors="pt", + ) + return {k: v.squeeze(0) for k, v in result.items()} + + +# initizlied the audioset map +_AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy") +_AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True) + + +def int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + + +def float32_to_int16(x): + x = np.clip(x, a_min=-1.0, a_max=1.0) + return (x * 32767.0).astype(np.int16) + + +# For Toy Dataset +# class ToyDataset(Dataset): +# def __init__(self, index_path, ipc, config, eval_mode=False): +# """Toy Dataset for testing the audioset input with text labels +# Parameters +# ---------- +# index_path: str +# the link to the h5 file of each audio +# idc: str +# the link to the npy file, the number of samples in each class +# config: dict +# the audio cfg file +# eval_model (bool): to indicate if the dataset is a testing dataset +# """ +# self.audio_cfg = config["audio_cfg"] +# self.text_cfg = config["text_cfg"] +# self.fp = h5py.File(index_path, "r") +# self.ipc = np.load(ipc, allow_pickle=True) +# self.total_size = len(self.fp["audio_name"]) +# self.classes_num = self.audio_cfg["class_num"] +# self.eval_mode = eval_mode + +# if not eval_mode: +# self.generate_queue() +# else: +# self.queue = [] +# for i in range(self.total_size): +# target = self.fp["target"][i] +# if np.sum(target) > 0: +# self.queue.append(i) +# self.total_size = len(self.queue) +# logging.info("total dataset size: %d" % (self.total_size)) +# logging.info("class num: %d" % (self.classes_num)) + +# def time_shifting(self, x): +# frame_num = len(x) +# shift_len = random.randint(0, frame_num - 1) +# new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0) +# return new_sample + +# def generate_queue(self): +# self.queue = [] +# while len(self.queue) < self.total_size: +# class_set = [*range(self.classes_num)] +# random.shuffle(class_set) +# self.queue += [ +# self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set +# ] +# self.queue = self.queue[: self.total_size] + +# logging.info("queue regenerated:%s" % (self.queue[-5:])) + +# def crop_wav(self, x): +# crop_size = self.audio_cfg["crop_size"] +# crop_pos = random.randint(0, len(x) - crop_size - 1) +# return x[crop_pos : crop_pos + crop_size] + +# def prompt_text(self, target): +# events = _AUDIOSET_MAP[np.where(target > 0)] +# event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1] +# text = tokenize(event_text)[0] +# return text + +# def __getitem__(self, index): +# """Load waveform, text, and target of an audio clip + +# Parameters +# ---------- +# index: int +# the index number +# Return +# ------ +# output: dict { +# "hdf5_path": str, +# "index_in_hdf5": int, +# "audio_name": str, +# "waveform": list (audio_length,), +# "target": list (class_num, ), +# "text": torch.tensor (context_length,) +# } +# the output dictionary +# """ +# s_index = self.queue[index] + +# audio_name = self.fp["audio_name"][s_index].decode() +# # Hardcode here CHANGE +# hdf5_path = ( +# self.fp["hdf5_path"][s_index] +# .decode() +# .replace( +# "../workspace", +# "/home/la/kechen/Research/ke_zsasp/workspace", +# ) +# ) +# r_idx = self.fp["index_in_hdf5"][s_index] +# target = self.fp["target"][s_index].astype(np.float32) +# text = self.prompt_text(target) +# with h5py.File(hdf5_path, "r") as f: +# waveform = int16_to_float32(f["waveform"][r_idx])[ +# : self.audio_cfg["clip_samples"] +# ] +# assert ( +# len(waveform) == self.audio_cfg["clip_samples"] +# ), "The sample length is not match" +# # Time shift +# # if (self.config.enable_time_shift) and (not self.eval_mode): +# # waveform = self.time_shifting(waveform) +# # # Label Enhance +# # if (self.config.crop_size is not None) and (not self.eval_mode): +# # waveform = self.crop_wav(waveform) +# # # the label enhance rate is fixed 0.5 +# # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5: +# # kidx = np.where(target)[0] +# # for k in kidx: +# # for add_key in self.class_map[k][1]: +# # target[add_key] = 1.0 +# # if len(self.class_map[k][2]) > 0: +# # add_key = random.choice(self.class_map[k][2]) +# # target[add_key] = 1.0 + +# # missing the text input +# mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :] +# mel_spec = ( +# torch.cat( +# [mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0 +# ) +# .cpu() +# .numpy() +# ) +# longer = random.choice([True, False]) +# if longer == False: +# mel_spec[1:, :, :] = 0.0 +# data_dict = { +# "hdf5_path": hdf5_path, +# "index_in_hdf5": r_idx, +# "audio_name": audio_name, +# "waveform": waveform, +# "class_label": target, +# "text": text, +# "longer": longer, +# "mel_fusion": mel_spec, +# } +# return data_dict + +# def __len__(self): +# return self.total_size + + +class CsvDataset(Dataset): + def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"): + logging.debug(f"Loading csv data from {input_filename}.") + df = pd.read_csv(input_filename, sep=sep) + + self.images = df[img_key].tolist() + self.captions = df[caption_key].tolist() + self.transforms = transforms + logging.debug("Done loading data.") + + def __len__(self): + return len(self.captions) + + def __getitem__(self, idx): + images = self.transforms(Image.open(str(self.images[idx]))) + texts = tokenize([str(self.captions[idx])])[0] + return images, texts + + +@dataclass +class DataInfo: + dataloader: DataLoader + sampler: DistributedSampler + + +def preprocess_txt(text): + return tokenize([str(text)])[0] + + +def get_dataset_size(shards, sizefilepath_=None, is_local=True): + if isinstance(shards, list): + size_list = [] + for s in shards: + size_list.append( + get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0] + ) + else: + if not is_local: + for n in dataset_split.keys(): + if n in shards.split("/"): + break + for s in dataset_split[n]: + if s in shards.split("/"): + break + sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" + shards_list = list(braceexpand.braceexpand(shards)) + dir_path = os.path.dirname(shards) + if sizefilepath_ is not None: + sizes = json.load(open(sizefilepath_, "r")) + total_size = sum( + [ + int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))]) + for shard in shards_list + ] + ) + else: + sizes_filename = os.path.join(dir_path, "sizes.json") + len_filename = os.path.join(dir_path, "__len__") + if os.path.exists(sizes_filename): + sizes = json.load(open(sizes_filename, "r")) + total_size = sum( + [int(sizes[os.path.basename(shard)]) for shard in shards_list] + ) + elif os.path.exists(len_filename): + # FIXME this used to be eval(open(...)) but that seemed rather unsafe + total_size = ast.literal_eval(open(len_filename, "r").read()) + else: + raise Exception( + "Cannot find sizes file for dataset. Please specify the path to the file." + ) + # total_size = None # num samples undefined + # some common dataset sizes (at time of authors last download) + # cc3m-train: 2905954 + # cc12m: 10968539 + # LAION-400m: 407332084 + num_shards = len(shards_list) + if isinstance(shards, list): + return sum(size_list), len(shards) + else: + return total_size, num_shards + + +def get_imagenet(args, preprocess_fns, split): + assert split in ["train", "val", "v2"] + is_train = split == "train" + preprocess_train, preprocess_val = preprocess_fns + + if split == "v2": + from imagenetv2_pytorch import ImageNetV2Dataset + + dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val) + else: + if is_train: + data_path = args.imagenet_train + preprocess_fn = preprocess_train + else: + data_path = args.imagenet_val + preprocess_fn = preprocess_val + assert data_path + + dataset = datasets.ImageFolder(data_path, transform=preprocess_fn) + + if is_train: + idxs = np.zeros(len(dataset.targets)) + target_array = np.array(dataset.targets) + k = 50 + for c in range(1000): + m = target_array == c + n = len(idxs[m]) + arr = np.zeros(n) + arr[:k] = 1 + np.random.shuffle(arr) + idxs[m] = arr + + idxs = idxs.astype("int") + sampler = SubsetRandomSampler(np.where(idxs)[0]) + else: + sampler = None + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=args.workers, + sampler=sampler, + ) + + return DataInfo(dataloader, sampler) + + +def count_samples(dataloader): + os.environ["WDS_EPOCH"] = "0" + n_elements, n_batches = 0, 0 + for images, texts in dataloader: + n_batches += 1 + n_elements += len(images) + assert len(images) == len(texts) + return n_elements, n_batches + + +def filter_no_caption(sample): + return "txt" in sample + + +def log_and_continue(exn): + """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" + logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.") + return True + + +_SHARD_SHUFFLE_SIZE = 2000 +_SHARD_SHUFFLE_INITIAL = 500 +_SAMPLE_SHUFFLE_SIZE = 5000 +_SAMPLE_SHUFFLE_INITIAL = 1000 + + +def sample_prop(sizefile, inputs, proportion, is_local=True): + """ + Sample a proportion of the data. + """ + file_path_dict = { + os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0] + for i in range(len(inputs)) + } + sampled_filepath_dict = {} + sampled_size_dict = {} + if not is_local: + if os.path.exists("sizes.json"): + os.remove("sizes.json") + wget.download(sizefile, "sizes.json") + sizefile = "sizes.json" + with open(sizefile, "r", encoding="UTF-8") as f: + load_dict = json.load(f) + L = int(len(file_path_dict) * proportion) + subkeys = random.sample(file_path_dict.keys(), L) + for k in subkeys: + sampled_size_dict[k] = load_dict[k] + sampled_filepath_dict[k] = file_path_dict[k] + return ( + sum(sampled_size_dict.values()), + L, + [os.path.join(v, k) for k, v in sampled_filepath_dict.items()], + sampled_size_dict, + ) + + +def get_mel(audio_data, audio_cfg): + # mel shape: (n_mels, T) + mel = torchaudio.transforms.MelSpectrogram( + sample_rate=audio_cfg["sample_rate"], + n_fft=audio_cfg["window_size"], + win_length=audio_cfg["window_size"], + hop_length=audio_cfg["hop_size"], + center=True, + pad_mode="reflect", + power=2.0, + norm=None, + onesided=True, + n_mels=64, + f_min=audio_cfg["fmin"], + f_max=audio_cfg["fmax"], + ).to(audio_data.device) + mel = mel(audio_data) + # Align to librosa: + # librosa_melspec = librosa.feature.melspectrogram( + # waveform, + # sr=audio_cfg['sample_rate'], + # n_fft=audio_cfg['window_size'], + # hop_length=audio_cfg['hop_size'], + # win_length=audio_cfg['window_size'], + # center=True, + # pad_mode="reflect", + # power=2.0, + # n_mels=64, + # norm=None, + # htk=True, + # f_min=audio_cfg['fmin'], + # f_max=audio_cfg['fmax'] + # ) + # we use log mel spectrogram as input + mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel) + return mel.T # (T, n_mels) + + +def get_audio_features( + sample, audio_data, max_len, data_truncating, data_filling, audio_cfg +): + """ + Calculate and add audio features to sample. + Sample: a dict containing all the data of current sample. + audio_data: a tensor of shape (T) containing audio data. + max_len: the maximum length of audio data. + data_truncating: the method of truncating data. + data_filling: the method of filling data. + audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg']. + """ + with torch.no_grad(): + if len(audio_data) > max_len: + if data_truncating == "rand_trunc": + longer = torch.tensor([True]) + elif data_truncating == "fusion": + # fusion + mel = get_mel(audio_data, audio_cfg) + # split to three parts + chunk_frames = ( + max_len // audio_cfg["hop_size"] + 1 + ) # the +1 related to how the spectrogram is computed + total_frames = mel.shape[0] + if chunk_frames == total_frames: + # there is a corner case where the audio length is + # larger than max_len but smaller than max_len+hop_size. + # In this case, we just use the whole audio. + mel_fusion = torch.stack([mel, mel, mel, mel], dim=0) + sample["mel_fusion"] = mel_fusion + longer = torch.tensor([False]) + else: + ranges = np.array_split( + list(range(0, total_frames - chunk_frames + 1)), 3 + ) + # print('total_frames-chunk_frames:', total_frames-chunk_frames, + # 'len(audio_data):', len(audio_data), + # 'chunk_frames:', chunk_frames, + # 'total_frames:', total_frames) + if len(ranges[1]) == 0: + # if the audio is too short, we just use the first chunk + ranges[1] = [0] + if len(ranges[2]) == 0: + # if the audio is too short, we just use the first chunk + ranges[2] = [0] + # randomly choose index for each part + idx_front = np.random.choice(ranges[0]) + idx_middle = np.random.choice(ranges[1]) + idx_back = np.random.choice(ranges[2]) + # select mel + mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :] + mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :] + mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :] + + # shrink the mel + mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, 64])( + mel[None] + )[0] + # logging.info(f"mel_shrink.shape: {mel_shrink.shape}") + + # stack + mel_fusion = torch.stack( + [mel_chunk_front, mel_chunk_middle, mel_chunk_back, mel_shrink], + dim=0, + ) + sample["mel_fusion"] = mel_fusion + longer = torch.tensor([True]) + else: + raise NotImplementedError( + f"data_truncating {data_truncating} not implemented" + ) + # random crop to max_len (for compatibility) + overflow = len(audio_data) - max_len + idx = np.random.randint(0, overflow + 1) + audio_data = audio_data[idx : idx + max_len] + + else: # padding if too short + if len(audio_data) < max_len: # do nothing if equal + if data_filling == "repeatpad": + n_repeat = int(max_len / len(audio_data)) + audio_data = audio_data.repeat(n_repeat) + # audio_data = audio_data.unsqueeze(0).unsqueeze(0).unsqueeze(0) + # audio_data = F.interpolate(audio_data,size=max_len,mode="bicubic")[0,0,0] + audio_data = F.pad( + audio_data, + (0, max_len - len(audio_data)), + mode="constant", + value=0, + ) + elif data_filling == "pad": + audio_data = F.pad( + audio_data, + (0, max_len - len(audio_data)), + mode="constant", + value=0, + ) + elif data_filling == "repeat": + n_repeat = int(max_len / len(audio_data)) + audio_data = audio_data.repeat(n_repeat + 1)[:max_len] + else: + raise NotImplementedError( + f"data_filling {data_filling} not implemented" + ) + if data_truncating == "fusion": + mel = get_mel(audio_data, audio_cfg) + mel_fusion = torch.stack([mel, mel, mel, mel], dim=0) + sample["mel_fusion"] = mel_fusion + longer = torch.tensor([False]) + + sample["longer"] = longer + sample["waveform"] = audio_data + + return sample + + +def preprocess( + sample, + audio_ext, + text_ext, + max_len, + audio_cfg, + class_index_dict=None, + data_filling="pad", + data_truncating="rand_trunc", + text_augment_selection=None, +): + """ + Preprocess a single sample for wdsdataloader. + """ + audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext])) + audio_data = int16_to_float32(float32_to_int16(audio_data)) + audio_data = torch.tensor(audio_data).float() + + # TODO: (yusong) to be include in the future + # # if torchaudio not installed, use soundfile to load audio + # if torchaudio is None: + # audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext])) + # audio_data = torch.tensor(audio_data).float() + # else: + # # https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py + # with tempfile.TemporaryDirectory() as dirname: + # os.makedirs(dirname, exist_ok=True) + # fname = os.path.join(dirname, f"file.flac") + # with open(fname, "wb") as stream: + # stream.write(sample[audio_ext]) + # audio_data, orig_sr = torchaudio.load(fname) + # audio_data = audio_data[0, :].float() + + sample = get_audio_features( + sample, audio_data, max_len, data_truncating, data_filling, audio_cfg + ) + del sample[audio_ext] + + try: + json_dict_raw = json.loads(sample[text_ext].decode("utf-8")) + except: + print("sample[__url__]:", sample["__url__"]) + + # For selecting augmented text from dataset + if text_augment_selection is None or text_augment_selection == "none": + texts = json_dict_raw["text"] + elif text_augment_selection == "all": + if "text_augment_all" in json_dict_raw.keys(): + texts = json_dict_raw["text_augment_all"] + else: + texts = json_dict_raw["text"] + elif text_augment_selection == "augment_only": + if "text_augment_all" in json_dict_raw.keys(): + if json_dict_raw["text_augment_t5"] is None: + texts = json_dict_raw["text"] + else: + texts = json_dict_raw["text_augment_t5"] + else: + texts = json_dict_raw["text"] + else: + raise NotImplementedError( + f"text_augment_selection {text_augment_selection} not implemented" + ) + sample["full_text"] = texts + + if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1: + texts = random.choice(texts) + sample["raw_text"] = texts + sample["text"] = tokenizer(texts) # text shape: [num_token] + if class_index_dict is not None: + # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing + # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array + # key, val = class_index_dict + # key = key[:].split('\n') + # _dict = {k: v for k, v in zip(key, val)} + sample["class_label"] = np.zeros(len(class_index_dict.keys())) + for x in json_dict_raw["tag"]: + sample["class_label"][class_index_dict[x]] = 1 + sample["class_label"] = torch.tensor(sample["class_label"]).float() + del sample[text_ext] + sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext + sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext + sample["audio_orig_sr"] = orig_sr + return sample + + +def collate_fn(batch): + """ + Collate function for wdsdataloader. + batch: a list of dict, each dict is a sample + """ + # concatenate values in each dictionary. if it is a tensor, concatenate. if it is a list, extend. + batch_dict = {} + for k in batch[0].keys(): + if isinstance(batch[0][k], dict): # dealwith bert tokenizer output + batch_dict[k] = {} + for kk in batch[0][k].keys(): + tmp = [] + for i in range(len(batch)): + tmp.append(batch[i][k][kk]) + batch_dict[k][kk] = torch.vstack(tmp) + elif isinstance(batch[0][k], torch.Tensor): + batch_dict[k] = torch.stack([sample[k] for sample in batch]) + elif isinstance(batch[0][k], np.ndarray): + batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in batch])) + else: + batch_dict[k] = [sample[k] for sample in batch] + return batch_dict + + +def get_wds_dataset( + args, + model_cfg, + is_train, + audio_ext="flac", + text_ext="json", + max_len=480000, + proportion=1.0, + sizefilepath_=None, + is_local=None, +): + """ + Get a dataset for wdsdataloader. + """ + if is_local is None and (not args.remotedata is None): + is_local = not args.remotedata + + input_shards = args.train_data if is_train else args.val_data + assert input_shards is not None + + if not sizefilepath_ is None: + sizefilepath = sizefilepath_ + else: + sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json") + + if proportion != 1.0: + num_samples, num_shards, input_shards, _ = sample_prop( + sizefilepath, input_shards, proportion, is_local=is_local + ) + else: + num_samples, num_shards = get_dataset_size( + input_shards, sizefilepath_=sizefilepath_, is_local=is_local + ) + + if not num_samples: + if is_train: + num_samples = args.train_num_samples + if not num_samples: + raise RuntimeError( + "Currently, number of dataset samples must be specified for training dataset. " + "Please specify via `--train-num-samples` if no dataset length info present." + ) + else: + num_samples = ( + args.val_num_samples or 0 + ) # eval will just exhaust the iterator if not specified + + pipeline = [wds.SimpleShardList(input_shards)] + # at this point we have an iterator over all the shards + # TODO: (yusong): add a if statement of distributed. If not, we don't need to split_by_node + if is_train or args.parallel_eval: + pipeline.extend( + [ + wds.detshuffle( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + ), + wds.split_by_node, + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker at each node + wds.tarfile_to_samples(handler=log_and_continue), + wds.shuffle( + bufsize=_SAMPLE_SHUFFLE_SIZE, + initial=_SAMPLE_SHUFFLE_INITIAL, + rng=random.Random(args.seed), + ), + # wds.repeatedly, # FIXME determine if this is beneficial + ] + ) + else: + pipeline.extend( + [ + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + ] + ) + pipeline.append( + wds.map( + partial( + preprocess, + audio_ext=audio_ext, + text_ext=text_ext, + max_len=max_len, + audio_cfg=model_cfg["audio_cfg"], + class_index_dict=copy.deepcopy(args.class_index_dict), + data_filling=args.data_filling, + data_truncating=args.data_truncating, + text_augment_selection=args.text_augment_selection, + ) + ), + ) + + pipeline.append( + wds.batched( + args.batch_size, + partial=not (is_train or args.parallel_eval), + collation_fn=collate_fn, + ) + ) + + dataset = wds.DataPipeline(*pipeline) + if is_train or args.parallel_eval: + # (yusong): Currently parallel evaluation will be not precise as we are repeat the last few samples. + # (yusong): See comments below. + # roll over and repeat a few samples to get same number of full batches on each node + global_batch_size = args.batch_size * args.world_size + num_batches = math.ceil(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = math.ceil( + num_batches / num_workers + ) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + dataset = dataset.with_epoch( + num_worker_batches + ) # each worker is iterating over this + else: + # last batches are partial, eval is done on single (master) node + num_batches = math.ceil(num_samples / args.batch_size) + + kwargs = {} + if args.horovod: # multi-node training on summit + kwargs["multiprocessing_context"] = "forkserver" + + dataloader = wds.WebLoader( + dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs + ) + + # FIXME not clear which approach is better, with_epoch before vs after dataloader? + # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 + # if is_train: + # # roll over and repeat a few samples to get same number of full batches on each node + # global_batch_size = args.batch_size * args.world_size + # num_batches = math.ceil(num_samples / global_batch_size) + # num_workers = max(1, args.workers) + # num_batches = math.ceil(num_batches / num_workers) * num_workers + # num_samples = num_batches * global_batch_size + # dataloader = dataloader.with_epoch(num_batches) + # else: + # # last batches are partial, eval is done on single (master) node + # num_batches = math.ceil(num_samples / args.batch_size) + + # add meta-data to dataloader instance for convenience + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples + + return DataInfo(dataloader, None) + + +def wds_batch_list2dict( + batch, + keys=[ + "__url__", + "__key__", + "waveform", + "text", + "raw_text", + "audio_name", + "text_name", + "audio_orig_sr", + ], +): + """ + Return a dictionary of the batch, with keys as the names of the fields. + """ + assert len(keys) == len( + batch + ), "batch must have same number of keys as keys argument" + return {keys[i]: batch[i] for i in range(len(batch))} + + +def get_csv_dataset(args, preprocess_fn, is_train): + input_filename = args.train_data if is_train else args.val_data + assert input_filename + dataset = CsvDataset( + input_filename, + preprocess_fn, + img_key=args.csv_img_key, + caption_key=args.csv_caption_key, + sep=args.csv_separator, + ) + num_samples = len(dataset) + sampler = DistributedSampler(dataset) if args.distributed and is_train else None + shuffle = is_train and sampler is None + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=shuffle, + num_workers=args.workers, + pin_memory=True, + sampler=sampler, + drop_last=is_train, + ) + dataloader.num_samples = num_samples + dataloader.num_batches = len(dataloader) + + return DataInfo(dataloader, sampler) + + +def get_toy_dataset(args, model_cfg, is_train): + index_path = args.train_data if is_train else args.val_data + ipc_path = args.train_ipc if is_train else args.val_ipc + assert index_path and ipc_path + eval_mode = not is_train + dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode) + + num_samples = len(dataset) + sampler = ( + DistributedSampler(dataset, shuffle=False) + if args.distributed and is_train + else None + ) + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.workers, + sampler=sampler, + drop_last=is_train, + ) + dataloader.num_samples = num_samples + dataloader.num_batches = len(dataloader) + + return DataInfo(dataloader, sampler) + + +def get_dataset_fn(data_path, dataset_type): + if dataset_type == "webdataset": + return get_wds_dataset + elif dataset_type == "csv": + return get_csv_dataset + elif dataset_type == "auto": + ext = data_path.split(".")[-1] + if ext in ["csv", "tsv"]: + return get_csv_dataset + elif ext in ["tar"]: + return get_wds_dataset + else: + raise ValueError( + f"Tried to figure out dataset type, but failed for extention {ext}." + ) + elif dataset_type == "toy": + return get_toy_dataset + else: + raise ValueError(f"Unsupported dataset type: {dataset_type}") + + +def get_data(args, model_cfg): + data = {} + + args.class_index_dict = load_class_label(args.class_label_path) + + if args.datasetinfos is None: + args.datasetinfos = ["train", "unbalanced_train", "balanced_train"] + if args.dataset_type == "webdataset": + args.train_data = get_tar_path_from_dataset_name( + args.datasetnames, + args.datasetinfos, + islocal=not args.remotedata, + proportion=args.dataset_proportion, + dataset_path=args.datasetpath, + full_dataset=args.full_train_dataset, + ) + + if args.full_train_dataset is None: + args.full_train_dataset = [] + if args.exclude_eval_dataset is None: + args.exclude_eval_dataset = [] + excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset + + val_dataset_names = ( + [n for n in args.datasetnames if n not in excluded_eval_datasets] + if excluded_eval_datasets + else args.datasetnames + ) + args.val_dataset_names = val_dataset_names + args.val_data = get_tar_path_from_dataset_name( + val_dataset_names, + ["valid", "test", "eval"], + islocal=not args.remotedata, + proportion=1, + dataset_path=args.datasetpath, + full_dataset=None, + ) + + if args.train_data: + data["train"] = get_dataset_fn(args.train_data, args.dataset_type)( + args, model_cfg, is_train=True + ) + + if args.val_data: + data["val"] = get_dataset_fn(args.val_data, args.dataset_type)( + args, model_cfg, is_train=False + ) + + return data diff --git a/audioldm/clap/training/distributed.py b/audioldm/clap/training/distributed.py new file mode 100755 index 0000000000000000000000000000000000000000..2fa61f76c5cc3ab9f6a9643042afa8e1f2e1cb7f --- /dev/null +++ b/audioldm/clap/training/distributed.py @@ -0,0 +1,150 @@ +import os + +import torch +import socket + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + + +def is_global_master(args): + return args.rank == 0 + + +def is_local_master(args): + return args.local_rank == 0 + + +def is_master(args, local=False): + return is_local_master(args) if local else is_global_master(args) + + +def is_using_horovod(): + # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set + # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... + ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] + pmi_vars = ["PMI_RANK", "PMI_SIZE"] + if all([var in os.environ for var in ompi_vars]) or all( + [var in os.environ for var in pmi_vars] + ): + return True + else: + return False + + +def is_using_distributed(): + if "WORLD_SIZE" in os.environ: + return int(os.environ["WORLD_SIZE"]) > 1 + if "SLURM_NTASKS" in os.environ: + return int(os.environ["SLURM_NTASKS"]) > 1 + return False + + +def world_info_from_env(): + local_rank = 0 + for v in ( + "SLURM_LOCALID", + "MPI_LOCALRANKID", + "OMPI_COMM_WORLD_LOCAL_RANK", + "LOCAL_RANK", + ): + if v in os.environ: + local_rank = int(os.environ[v]) + break + global_rank = 0 + for v in ("SLURM_PROCID", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "RANK"): + if v in os.environ: + global_rank = int(os.environ[v]) + break + world_size = 1 + for v in ("SLURM_NTASKS", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "WORLD_SIZE"): + if v in os.environ: + world_size = int(os.environ[v]) + break + + return local_rank, global_rank, world_size + + +def init_distributed_device(args): + # Distributed training = training on more than one GPU. + # Works in both single and multi-node scenarios. + args.distributed = False + args.world_size = 1 + args.rank = 0 # global rank + args.local_rank = 0 + if args.horovod: + assert hvd is not None, "Horovod is not installed" + hvd.init() + world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) + world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + args.local_rank = local_rank + args.rank = world_rank + args.world_size = world_size + # args.local_rank = int(hvd.local_rank()) + # args.rank = hvd.rank() + # args.world_size = hvd.size() + args.distributed = True + os.environ["LOCAL_RANK"] = str(args.local_rank) + os.environ["RANK"] = str(args.rank) + os.environ["WORLD_SIZE"] = str(args.world_size) + print( + f"Distributed training: local_rank={args.local_rank}, " + f"rank={args.rank}, world_size={args.world_size}, " + f"hostname={socket.gethostname()}, pid={os.getpid()}" + ) + elif is_using_distributed(): + if "SLURM_PROCID" in os.environ: + # DDP via SLURM + args.local_rank, args.rank, args.world_size = world_info_from_env() + # SLURM var -> torch.distributed vars in case needed + os.environ["LOCAL_RANK"] = str(args.local_rank) + os.environ["RANK"] = str(args.rank) + os.environ["WORLD_SIZE"] = str(args.world_size) + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + elif "OMPI_COMM_WORLD_SIZE" in os.environ: # using Summit cluster + world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) + world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + args.local_rank = local_rank + args.rank = world_rank + args.world_size = world_size + torch.distributed.init_process_group( + backend=args.dist_backend, + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + else: + # DDP via torchrun, torch.distributed.launch + args.local_rank, _, _ = world_info_from_env() + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url + ) + args.world_size = torch.distributed.get_world_size() + args.rank = torch.distributed.get_rank() + args.distributed = True + print( + f"Distributed training: local_rank={args.local_rank}, " + f"rank={args.rank}, world_size={args.world_size}, " + f"hostname={socket.gethostname()}, pid={os.getpid()}" + ) + + if torch.cuda.is_available(): + if args.distributed and not args.no_set_device_rank: + device = "cuda:%d" % args.local_rank + else: + device = "cuda:0" + torch.cuda.set_device(device) + else: + device = "cpu" + args.device = device + device = torch.device(device) + return device diff --git a/audioldm/clap/training/imagenet_zeroshot_data.py b/audioldm/clap/training/imagenet_zeroshot_data.py new file mode 100755 index 0000000000000000000000000000000000000000..d32e55328d6799ccb8d61625f43abb80a33d6c17 --- /dev/null +++ b/audioldm/clap/training/imagenet_zeroshot_data.py @@ -0,0 +1,1088 @@ +# NOTE: This script is currently not supported for CLAP. + +imagenet_classnames = [ + "tench", + "goldfish", + "great white shark", + "tiger shark", + "hammerhead shark", + "electric ray", + "stingray", + "rooster", + "hen", + "ostrich", + "brambling", + "goldfinch", + "house finch", + "junco", + "indigo bunting", + "American robin", + "bulbul", + "jay", + "magpie", + "chickadee", + "American dipper", + "kite (bird of prey)", + "bald eagle", + "vulture", + "great grey owl", + "fire salamander", + "smooth newt", + "newt", + "spotted salamander", + "axolotl", + "American bullfrog", + "tree frog", + "tailed frog", + "loggerhead sea turtle", + "leatherback sea turtle", + "mud turtle", + "terrapin", + "box turtle", + "banded gecko", + "green iguana", + "Carolina anole", + "desert grassland whiptail lizard", + "agama", + "frilled-necked lizard", + "alligator lizard", + "Gila monster", + "European green lizard", + "chameleon", + "Komodo dragon", + "Nile crocodile", + "American alligator", + "triceratops", + "worm snake", + "ring-necked snake", + "eastern hog-nosed snake", + "smooth green snake", + "kingsnake", + "garter snake", + "water snake", + "vine snake", + "night snake", + "boa constrictor", + "African rock python", + "Indian cobra", + "green mamba", + "sea snake", + "Saharan horned viper", + "eastern diamondback rattlesnake", + "sidewinder rattlesnake", + "trilobite", + "harvestman", + "scorpion", + "yellow garden spider", + "barn spider", + "European garden spider", + "southern black widow", + "tarantula", + "wolf spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse", + "prairie grouse", + "peafowl", + "quail", + "partridge", + "african grey parrot", + "macaw", + "sulphur-crested cockatoo", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + "jacamar", + "toucan", + "duck", + "red-breasted merganser", + "goose", + "black swan", + "tusker", + "echidna", + "platypus", + "wallaby", + "koala", + "wombat", + "jellyfish", + "sea anemone", + "brain coral", + "flatworm", + "nematode", + "conch", + "snail", + "slug", + "sea slug", + "chiton", + "chambered nautilus", + "Dungeness crab", + "rock crab", + "fiddler crab", + "red king crab", + "American lobster", + "spiny lobster", + "crayfish", + "hermit crab", + "isopod", + "white stork", + "black stork", + "spoonbill", + "flamingo", + "little blue heron", + "great egret", + "bittern bird", + "crane bird", + "limpkin", + "common gallinule", + "American coot", + "bustard", + "ruddy turnstone", + "dunlin", + "common redshank", + "dowitcher", + "oystercatcher", + "pelican", + "king penguin", + "albatross", + "grey whale", + "killer whale", + "dugong", + "sea lion", + "Chihuahua", + "Japanese Chin", + "Maltese", + "Pekingese", + "Shih Tzu", + "King Charles Spaniel", + "Papillon", + "toy terrier", + "Rhodesian Ridgeback", + "Afghan Hound", + "Basset Hound", + "Beagle", + "Bloodhound", + "Bluetick Coonhound", + "Black and Tan Coonhound", + "Treeing Walker Coonhound", + "English foxhound", + "Redbone Coonhound", + "borzoi", + "Irish Wolfhound", + "Italian Greyhound", + "Whippet", + "Ibizan Hound", + "Norwegian Elkhound", + "Otterhound", + "Saluki", + "Scottish Deerhound", + "Weimaraner", + "Staffordshire Bull Terrier", + "American Staffordshire Terrier", + "Bedlington Terrier", + "Border Terrier", + "Kerry Blue Terrier", + "Irish Terrier", + "Norfolk Terrier", + "Norwich Terrier", + "Yorkshire Terrier", + "Wire Fox Terrier", + "Lakeland Terrier", + "Sealyham Terrier", + "Airedale Terrier", + "Cairn Terrier", + "Australian Terrier", + "Dandie Dinmont Terrier", + "Boston Terrier", + "Miniature Schnauzer", + "Giant Schnauzer", + "Standard Schnauzer", + "Scottish Terrier", + "Tibetan Terrier", + "Australian Silky Terrier", + "Soft-coated Wheaten Terrier", + "West Highland White Terrier", + "Lhasa Apso", + "Flat-Coated Retriever", + "Curly-coated Retriever", + "Golden Retriever", + "Labrador Retriever", + "Chesapeake Bay Retriever", + "German Shorthaired Pointer", + "Vizsla", + "English Setter", + "Irish Setter", + "Gordon Setter", + "Brittany dog", + "Clumber Spaniel", + "English Springer Spaniel", + "Welsh Springer Spaniel", + "Cocker Spaniel", + "Sussex Spaniel", + "Irish Water Spaniel", + "Kuvasz", + "Schipperke", + "Groenendael dog", + "Malinois", + "Briard", + "Australian Kelpie", + "Komondor", + "Old English Sheepdog", + "Shetland Sheepdog", + "collie", + "Border Collie", + "Bouvier des Flandres dog", + "Rottweiler", + "German Shepherd Dog", + "Dobermann", + "Miniature Pinscher", + "Greater Swiss Mountain Dog", + "Bernese Mountain Dog", + "Appenzeller Sennenhund", + "Entlebucher Sennenhund", + "Boxer", + "Bullmastiff", + "Tibetan Mastiff", + "French Bulldog", + "Great Dane", + "St. Bernard", + "husky", + "Alaskan Malamute", + "Siberian Husky", + "Dalmatian", + "Affenpinscher", + "Basenji", + "pug", + "Leonberger", + "Newfoundland dog", + "Great Pyrenees dog", + "Samoyed", + "Pomeranian", + "Chow Chow", + "Keeshond", + "brussels griffon", + "Pembroke Welsh Corgi", + "Cardigan Welsh Corgi", + "Toy Poodle", + "Miniature Poodle", + "Standard Poodle", + "Mexican hairless dog (xoloitzcuintli)", + "grey wolf", + "Alaskan tundra wolf", + "red wolf or maned wolf", + "coyote", + "dingo", + "dhole", + "African wild dog", + "hyena", + "red fox", + "kit fox", + "Arctic fox", + "grey fox", + "tabby cat", + "tiger cat", + "Persian cat", + "Siamese cat", + "Egyptian Mau", + "cougar", + "lynx", + "leopard", + "snow leopard", + "jaguar", + "lion", + "tiger", + "cheetah", + "brown bear", + "American black bear", + "polar bear", + "sloth bear", + "mongoose", + "meerkat", + "tiger beetle", + "ladybug", + "ground beetle", + "longhorn beetle", + "leaf beetle", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant", + "grasshopper", + "cricket insect", + "stick insect", + "cockroach", + "praying mantis", + "cicada", + "leafhopper", + "lacewing", + "dragonfly", + "damselfly", + "red admiral butterfly", + "ringlet butterfly", + "monarch butterfly", + "small white butterfly", + "sulphur butterfly", + "gossamer-winged butterfly", + "starfish", + "sea urchin", + "sea cucumber", + "cottontail rabbit", + "hare", + "Angora rabbit", + "hamster", + "porcupine", + "fox squirrel", + "marmot", + "beaver", + "guinea pig", + "common sorrel horse", + "zebra", + "pig", + "wild boar", + "warthog", + "hippopotamus", + "ox", + "water buffalo", + "bison", + "ram (adult male sheep)", + "bighorn sheep", + "Alpine ibex", + "hartebeest", + "impala (antelope)", + "gazelle", + "arabian camel", + "llama", + "weasel", + "mink", + "European polecat", + "black-footed ferret", + "otter", + "skunk", + "badger", + "armadillo", + "three-toed sloth", + "orangutan", + "gorilla", + "chimpanzee", + "gibbon", + "siamang", + "guenon", + "patas monkey", + "baboon", + "macaque", + "langur", + "black-and-white colobus", + "proboscis monkey", + "marmoset", + "white-headed capuchin", + "howler monkey", + "titi monkey", + "Geoffroy's spider monkey", + "common squirrel monkey", + "ring-tailed lemur", + "indri", + "Asian elephant", + "African bush elephant", + "red panda", + "giant panda", + "snoek fish", + "eel", + "silver salmon", + "rock beauty fish", + "clownfish", + "sturgeon", + "gar fish", + "lionfish", + "pufferfish", + "abacus", + "abaya", + "academic gown", + "accordion", + "acoustic guitar", + "aircraft carrier", + "airliner", + "airship", + "altar", + "ambulance", + "amphibious vehicle", + "analog clock", + "apiary", + "apron", + "trash can", + "assault rifle", + "backpack", + "bakery", + "balance beam", + "balloon", + "ballpoint pen", + "Band-Aid", + "banjo", + "baluster / handrail", + "barbell", + "barber chair", + "barbershop", + "barn", + "barometer", + "barrel", + "wheelbarrow", + "baseball", + "basketball", + "bassinet", + "bassoon", + "swimming cap", + "bath towel", + "bathtub", + "station wagon", + "lighthouse", + "beaker", + "military hat (bearskin or shako)", + "beer bottle", + "beer glass", + "bell tower", + "baby bib", + "tandem bicycle", + "bikini", + "ring binder", + "binoculars", + "birdhouse", + "boathouse", + "bobsleigh", + "bolo tie", + "poke bonnet", + "bookcase", + "bookstore", + "bottle cap", + "hunting bow", + "bow tie", + "brass memorial plaque", + "bra", + "breakwater", + "breastplate", + "broom", + "bucket", + "buckle", + "bulletproof vest", + "high-speed train", + "butcher shop", + "taxicab", + "cauldron", + "candle", + "cannon", + "canoe", + "can opener", + "cardigan", + "car mirror", + "carousel", + "tool kit", + "cardboard box / carton", + "car wheel", + "automated teller machine", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello", + "mobile phone", + "chain", + "chain-link fence", + "chain mail", + "chainsaw", + "storage chest", + "chiffonier", + "bell or wind chime", + "china cabinet", + "Christmas stocking", + "church", + "movie theater", + "cleaver", + "cliff dwelling", + "cloak", + "clogs", + "cocktail shaker", + "coffee mug", + "coffeemaker", + "spiral or coil", + "combination lock", + "computer keyboard", + "candy store", + "container ship", + "convertible", + "corkscrew", + "cornet", + "cowboy boot", + "cowboy hat", + "cradle", + "construction crane", + "crash helmet", + "crate", + "infant bed", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam", + "desk", + "desktop computer", + "rotary dial telephone", + "diaper", + "digital clock", + "digital watch", + "dining table", + "dishcloth", + "dishwasher", + "disc brake", + "dock", + "dog sled", + "dome", + "doormat", + "drilling rig", + "drum", + "drumstick", + "dumbbell", + "Dutch oven", + "electric fan", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso machine", + "face powder", + "feather boa", + "filing cabinet", + "fireboat", + "fire truck", + "fire screen", + "flagpole", + "flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster bed", + "freight car", + "French horn", + "frying pan", + "fur coat", + "garbage truck", + "gas mask or respirator", + "gas pump", + "goblet", + "go-kart", + "golf ball", + "golf cart", + "gondola", + "gong", + "gown", + "grand piano", + "greenhouse", + "radiator grille", + "grocery store", + "guillotine", + "hair clip", + "hair spray", + "half-track", + "hammer", + "hamper", + "hair dryer", + "hand-held computer", + "handkerchief", + "hard disk drive", + "harmonica", + "harp", + "combine harvester", + "hatchet", + "holster", + "home theater", + "honeycomb", + "hook", + "hoop skirt", + "gymnastic horizontal bar", + "horse-drawn vehicle", + "hourglass", + "iPod", + "clothes iron", + "carved pumpkin", + "jeans", + "jeep", + "T-shirt", + "jigsaw puzzle", + "rickshaw", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat", + "ladle", + "lampshade", + "laptop computer", + "lawn mower", + "lens cap", + "letter opener", + "library", + "lifeboat", + "lighter", + "limousine", + "ocean liner", + "lipstick", + "slip-on shoe", + "lotion", + "music speaker", + "loupe magnifying glass", + "sawmill", + "magnetic compass", + "messenger bag", + "mailbox", + "tights", + "one-piece bathing suit", + "manhole cover", + "maraca", + "marimba", + "mask", + "matchstick", + "maypole", + "maze", + "measuring cup", + "medicine cabinet", + "megalith", + "microphone", + "microwave oven", + "military uniform", + "milk can", + "minibus", + "miniskirt", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home", + "ford model t", + "modem", + "monastery", + "monitor", + "moped", + "mortar and pestle", + "graduation cap", + "mosque", + "mosquito net", + "vespa", + "mountain bike", + "tent", + "computer mouse", + "mousetrap", + "moving van", + "muzzle", + "metal nail", + "neck brace", + "necklace", + "baby pacifier", + "notebook computer", + "obelisk", + "oboe", + "ocarina", + "odometer", + "oil filter", + "pipe organ", + "oscilloscope", + "overskirt", + "bullock cart", + "oxygen mask", + "product packet / packaging", + "paddle", + "paddle wheel", + "padlock", + "paintbrush", + "pajamas", + "palace", + "pan flute", + "paper towel", + "parachute", + "parallel bars", + "park bench", + "parking meter", + "railroad car", + "patio", + "payphone", + "pedestal", + "pencil case", + "pencil sharpener", + "perfume", + "Petri dish", + "photocopier", + "plectrum", + "Pickelhaube", + "picket fence", + "pickup truck", + "pier", + "piggy bank", + "pill bottle", + "pillow", + "ping-pong ball", + "pinwheel", + "pirate ship", + "drink pitcher", + "block plane", + "planetarium", + "plastic bag", + "plate rack", + "farm plow", + "plunger", + "Polaroid camera", + "pole", + "police van", + "poncho", + "pool table", + "soda bottle", + "plant pot", + "potter's wheel", + "power drill", + "prayer rug", + "printer", + "prison", + "missile", + "projector", + "hockey puck", + "punching bag", + "purse", + "quill", + "quilt", + "race car", + "racket", + "radiator", + "radio", + "radio telescope", + "rain barrel", + "recreational vehicle", + "fishing casting reel", + "reflex camera", + "refrigerator", + "remote control", + "restaurant", + "revolver", + "rifle", + "rocking chair", + "rotisserie", + "eraser", + "rugby ball", + "ruler measuring stick", + "sneaker", + "safe", + "safety pin", + "salt shaker", + "sandal", + "sarong", + "saxophone", + "scabbard", + "weighing scale", + "school bus", + "schooner", + "scoreboard", + "CRT monitor", + "screw", + "screwdriver", + "seat belt", + "sewing machine", + "shield", + "shoe store", + "shoji screen / room divider", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + "balaclava ski mask", + "sleeping bag", + "slide rule", + "sliding door", + "slot machine", + "snorkel", + "snowmobile", + "snowplow", + "soap dispenser", + "soccer ball", + "sock", + "solar thermal collector", + "sombrero", + "soup bowl", + "keyboard space bar", + "space heater", + "space shuttle", + "spatula", + "motorboat", + "spider web", + "spindle", + "sports car", + "spotlight", + "stage", + "steam locomotive", + "through arch bridge", + "steel drum", + "stethoscope", + "scarf", + "stone wall", + "stopwatch", + "stove", + "strainer", + "tram", + "stretcher", + "couch", + "stupa", + "submarine", + "suit", + "sundial", + "sunglasses", + "sunglasses", + "sunscreen", + "suspension bridge", + "mop", + "sweatshirt", + "swim trunks / shorts", + "swing", + "electrical switch", + "syringe", + "table lamp", + "tank", + "tape player", + "teapot", + "teddy bear", + "television", + "tennis ball", + "thatched roof", + "front curtain", + "thimble", + "threshing machine", + "throne", + "tile roof", + "toaster", + "tobacco shop", + "toilet seat", + "torch", + "totem pole", + "tow truck", + "toy store", + "tractor", + "semi-trailer truck", + "tray", + "trench coat", + "tricycle", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus", + "trombone", + "hot tub", + "turnstile", + "typewriter keyboard", + "umbrella", + "unicycle", + "upright piano", + "vacuum cleaner", + "vase", + "vaulted or arched ceiling", + "velvet fabric", + "vending machine", + "vestment", + "viaduct", + "violin", + "volleyball", + "waffle iron", + "wall clock", + "wallet", + "wardrobe", + "military aircraft", + "sink", + "washing machine", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "hair wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "airplane wing", + "wok", + "wooden spoon", + "wool", + "split-rail fence", + "shipwreck", + "sailboat", + "yurt", + "website", + "comic book", + "crossword", + "traffic or street sign", + "traffic light", + "dust jacket", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot", + "trifle", + "ice cream", + "popsicle", + "baguette", + "bagel", + "pretzel", + "cheeseburger", + "hot dog", + "mashed potatoes", + "cabbage", + "broccoli", + "cauliflower", + "zucchini", + "spaghetti squash", + "acorn squash", + "butternut squash", + "cucumber", + "artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith apple", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple", + "banana", + "jackfruit", + "cherimoya (custard apple)", + "pomegranate", + "hay", + "carbonara", + "chocolate syrup", + "dough", + "meatloaf", + "pizza", + "pot pie", + "burrito", + "red wine", + "espresso", + "tea cup", + "eggnog", + "mountain", + "bubble", + "cliff", + "coral reef", + "geyser", + "lakeshore", + "promontory", + "sandbar", + "beach", + "valley", + "volcano", + "baseball player", + "bridegroom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady's slipper", + "corn", + "acorn", + "rose hip", + "horse chestnut seed", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn mushroom", + "earth star fungus", + "hen of the woods mushroom", + "bolete", + "corn cob", + "toilet paper", +] + + +openai_imagenet_template = [ + lambda c: f"a bad photo of a {c}.", + lambda c: f"a photo of many {c}.", + lambda c: f"a sculpture of a {c}.", + lambda c: f"a photo of the hard to see {c}.", + lambda c: f"a low resolution photo of the {c}.", + lambda c: f"a rendering of a {c}.", + lambda c: f"graffiti of a {c}.", + lambda c: f"a bad photo of the {c}.", + lambda c: f"a cropped photo of the {c}.", + lambda c: f"a tattoo of a {c}.", + lambda c: f"the embroidered {c}.", + lambda c: f"a photo of a hard to see {c}.", + lambda c: f"a bright photo of a {c}.", + lambda c: f"a photo of a clean {c}.", + lambda c: f"a photo of a dirty {c}.", + lambda c: f"a dark photo of the {c}.", + lambda c: f"a drawing of a {c}.", + lambda c: f"a photo of my {c}.", + lambda c: f"the plastic {c}.", + lambda c: f"a photo of the cool {c}.", + lambda c: f"a close-up photo of a {c}.", + lambda c: f"a black and white photo of the {c}.", + lambda c: f"a painting of the {c}.", + lambda c: f"a painting of a {c}.", + lambda c: f"a pixelated photo of the {c}.", + lambda c: f"a sculpture of the {c}.", + lambda c: f"a bright photo of the {c}.", + lambda c: f"a cropped photo of a {c}.", + lambda c: f"a plastic {c}.", + lambda c: f"a photo of the dirty {c}.", + lambda c: f"a jpeg corrupted photo of a {c}.", + lambda c: f"a blurry photo of the {c}.", + lambda c: f"a photo of the {c}.", + lambda c: f"a good photo of the {c}.", + lambda c: f"a rendering of the {c}.", + lambda c: f"a {c} in a video game.", + lambda c: f"a photo of one {c}.", + lambda c: f"a doodle of a {c}.", + lambda c: f"a close-up photo of the {c}.", + lambda c: f"a photo of a {c}.", + lambda c: f"the origami {c}.", + lambda c: f"the {c} in a video game.", + lambda c: f"a sketch of a {c}.", + lambda c: f"a doodle of the {c}.", + lambda c: f"a origami {c}.", + lambda c: f"a low resolution photo of a {c}.", + lambda c: f"the toy {c}.", + lambda c: f"a rendition of the {c}.", + lambda c: f"a photo of the clean {c}.", + lambda c: f"a photo of a large {c}.", + lambda c: f"a rendition of a {c}.", + lambda c: f"a photo of a nice {c}.", + lambda c: f"a photo of a weird {c}.", + lambda c: f"a blurry photo of a {c}.", + lambda c: f"a cartoon {c}.", + lambda c: f"art of a {c}.", + lambda c: f"a sketch of the {c}.", + lambda c: f"a embroidered {c}.", + lambda c: f"a pixelated photo of a {c}.", + lambda c: f"itap of the {c}.", + lambda c: f"a jpeg corrupted photo of the {c}.", + lambda c: f"a good photo of a {c}.", + lambda c: f"a plushie {c}.", + lambda c: f"a photo of the nice {c}.", + lambda c: f"a photo of the small {c}.", + lambda c: f"a photo of the weird {c}.", + lambda c: f"the cartoon {c}.", + lambda c: f"art of the {c}.", + lambda c: f"a drawing of the {c}.", + lambda c: f"a photo of the large {c}.", + lambda c: f"a black and white photo of a {c}.", + lambda c: f"the plushie {c}.", + lambda c: f"a dark photo of a {c}.", + lambda c: f"itap of a {c}.", + lambda c: f"graffiti of the {c}.", + lambda c: f"a toy {c}.", + lambda c: f"itap of my {c}.", + lambda c: f"a photo of a cool {c}.", + lambda c: f"a photo of a small {c}.", + lambda c: f"a tattoo of the {c}.", +] diff --git a/audioldm/clap/training/infer_demo.py b/audioldm/clap/training/infer_demo.py new file mode 100755 index 0000000000000000000000000000000000000000..7d1f4784898dbfeb69affefb6f624711adc8cb42 --- /dev/null +++ b/audioldm/clap/training/infer_demo.py @@ -0,0 +1,105 @@ +import sys + +import os +import torch +import librosa +from open_clip import create_model +from training.data import get_audio_features +from training.data import int16_to_float32, float32_to_int16 +from transformers import RobertaTokenizer + +tokenize = RobertaTokenizer.from_pretrained("roberta-base") + + +def tokenizer(text): + result = tokenize( + text, + padding="max_length", + truncation=True, + max_length=77, + return_tensors="pt", + ) + return {k: v.squeeze(0) for k, v in result.items()} + + +PRETRAINED_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/checkpoints/epoch_top_0_audioset_no_fusion.pt" +WAVE_48k_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/audio/machine.wav" + + +def infer_text(): + device = "cuda:0" if torch.cuda.is_available() else "cpu" + precision = "fp32" + amodel = "HTSAT-tiny" # or 'PANN-14' + tmodel = "roberta" # the best text encoder in our training + enable_fusion = False # False if you do not want to use the fusion model + fusion_type = "aff_2d" + pretrained = PRETRAINED_PATH + + model, model_cfg = create_model( + amodel, + tmodel, + pretrained, + precision=precision, + device=device, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + # load the text, can be a list (i.e. batch size) + text_data = ["I love the contrastive learning", "I love the pretrain model"] + # tokenize for roberta, if you want to tokenize for another text encoder, please refer to data.py#L43-90 + text_data = tokenizer(text_data) + + text_embed = model.get_text_embedding(text_data) + print(text_embed.size()) + + +def infer_audio(): + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + precision = "fp32" + amodel = "HTSAT-tiny" # or 'PANN-14' + tmodel = "roberta" # the best text encoder in our training + enable_fusion = False # False if you do not want to use the fusion model + fusion_type = "aff_2d" + pretrained = PRETRAINED_PATH + + model, model_cfg = create_model( + amodel, + tmodel, + pretrained, + precision=precision, + device=device, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + + # load the waveform of the shape (T,), should resample to 48000 + audio_waveform, sr = librosa.load(WAVE_48k_PATH, sr=48000) + # quantize + audio_waveform = int16_to_float32(float32_to_int16(audio_waveform)) + audio_waveform = torch.from_numpy(audio_waveform).float() + audio_dict = {} + + # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode + import ipdb + + ipdb.set_trace() + audio_dict = get_audio_features( + audio_dict, + audio_waveform, + 480000, + data_truncating="fusion", + data_filling="repeatpad", + audio_cfg=model_cfg["audio_cfg"], + ) + # can send a list to the model, to process many audio tracks in one time (i.e. batch size) + audio_embed = model.get_audio_embedding([audio_dict]) + print(audio_embed.size()) + import ipdb + + ipdb.set_trace() + + +if __name__ == "__main__": + infer_text() + infer_audio() diff --git a/audioldm/clap/training/logger.py b/audioldm/clap/training/logger.py new file mode 100755 index 0000000000000000000000000000000000000000..ac4634970fae6aacde2b7b808355dbd50c90ce73 --- /dev/null +++ b/audioldm/clap/training/logger.py @@ -0,0 +1,30 @@ +import logging + + +def setup_logging(log_file, level, include_host=False): + if include_host: + import socket + + hostname = socket.gethostname() + formatter = logging.Formatter( + f"%(asctime)s | {hostname} | %(levelname)s | %(message)s", + datefmt="%Y-%m-%d,%H:%M:%S", + ) + else: + formatter = logging.Formatter( + "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d,%H:%M:%S" + ) + + logging.root.setLevel(level) + loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] + for logger in loggers: + logger.setLevel(level) + + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(formatter) + logging.root.addHandler(stream_handler) + + if log_file: + file_handler = logging.FileHandler(filename=log_file) + file_handler.setFormatter(formatter) + logging.root.addHandler(file_handler) diff --git a/audioldm/clap/training/lp_main.py b/audioldm/clap/training/lp_main.py new file mode 100755 index 0000000000000000000000000000000000000000..c2d4e8c85aaa3c8e4221963ef56a815cc14f354f --- /dev/null +++ b/audioldm/clap/training/lp_main.py @@ -0,0 +1,670 @@ +from cmath import cos +from inspect import getargs +import logging +import os +import random +from datetime import datetime +import bisect +import copy +from sched import scheduler +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from torch import optim +from torch.cuda.amp import GradScaler +import faulthandler +import pathlib +import argparse +import time + +try: + import wandb +except ImportError: + wandb = None + +try: + import torch.utils.tensorboard as tensorboard +except ImportError: + tensorboard = None + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +from open_clip import create_model_and_transforms, trace_model, create_model +from training.data import get_data +from training.params import parse_args +from training.distributed import is_master, init_distributed_device, world_info_from_env +from training.logger import setup_logging +from training.scheduler import cosine_lr +from training.lp_train import train_one_epoch, evaluate +from open_clip.utils import get_tar_path_from_dataset_name, dataset_split, get_optimizer +from open_clip.utils import load_p, load_class_label +from open_clip.linear_probe import LinearProbe + + +def maintain_ckpts(args, startidx, all_idx_len): + for i in reversed(range(startidx, all_idx_len)): + if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")): + os.rename( + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"), + ) + if os.path.exists( + os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt") + ): + os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")) + return + + +def update_top_k_performance( + new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True +): + """ + Record the top-k performance of the current epoch. + current_top_k_metrics is a dictionary of the form: {1: top_1_ckpt_measure, 2: top_2_ckpt_measure, ...} + """ + if isinstance(new_metrics_inputs, (list, tuple)): + new_metrics_inputs = np.mean(new_metrics_inputs) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + ) + elif isinstance(new_metrics_inputs, dict): + new_metrics_inputs = np.mean(list(new_metrics_inputs.values())) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + ) + elif isinstance(new_metrics_inputs, (float, int)): + update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()} + sorted_keys = sorted(current_top_k_ckpt_metrics.keys()) + sorted_values = sorted( + current_top_k_ckpt_metrics.values(), reverse=bignumbetter + ) + sorted_values_ = copy.deepcopy(sorted_values) + sorted_values.append(new_metrics_inputs) + sorted_values = sorted(sorted_values, reverse=bignumbetter) + sorted_values = sorted_values[:-1] + + if sorted_values == sorted_values_: + return current_top_k_ckpt_metrics, new_metrics_inputs + else: + for i in range(len(sorted_keys)): + if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]: + current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i] + update_flag[sorted_keys[i]] = True + for i in range(len(update_flag)): + if update_flag[i]: + maintain_ckpts(args, i, len(sorted_keys)) + torch.save( + ckpt, + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + ) + break + return current_top_k_ckpt_metrics, new_metrics_inputs + + +# def updateifNone(a, b): +# a = b if None else a +# return a + + +def is_pretrained_params(n): + return ( + n.startswith("clap_model.transformer") + or n in ["clap_model.positional_embedding", "clap_model.text_projection"] + or n.startswith("clap_model.token_embedding") + or n.startswith("clap_model.ln_final") + or n.startswith("clap_model.logit_scale_t") + ) + + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) + + +def config_lp_optimizer(model, data, args): + # set wd-related params to 0 if use adam optimizer + if args.optimizer == "adam": + args.wd = 0 + args.wd_pretrained = 0 + args.wd_new = 0 + + in_clap = lambda n, p: n.startswith("clap_model") + + named_parameters = list(model.named_parameters()) + + optimizer = {} + scheduler = {} + + # freeze text encoder + text_freeze_parameters = [ + p + for n, p in named_parameters + if n.startswith("clap_model.transformer") + or n in ["clap_model.positional_embedding", "clap_model.text_projection"] + or n.startswith("clap_model.token_embedding") + or n.startswith("clap_model.ln_final") + ] + + if args.freeze_text: + logging.info("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + + if not args.lp_freeze: + exclude = ( + lambda n, p: p.ndim < 2 + or "bn" in n + or "ln" in n + or "bias" in n + or "logit_scale" in n + ) + include = lambda n, p: not exclude(n, p) + + # (yusong): we do not split the learning rate anymore + # p for n, p in named_parameters if in_clap(n,p) and exclude(n, p) and p.requires_grad + gain_or_bias_params = [ + p for n, p in named_parameters if exclude(n, p) and p.requires_grad + ] + # rest_params = [p for n, p in named_parameters if in_clap(n,p) and include(n, p) and p.requires_grad] + rest_params = [ + p for n, p in named_parameters if include(n, p) and p.requires_grad + ] + + if args.train_data is None: + optimizer = None + scheduler = None + else: + total_steps = data["train"].dataloader.num_batches * args.epochs + + if args.split_opt: + for x in ["lr", "beta1", "beta2", "eps", "wd"]: + for y in ["_new", "_pretrained"]: + if getattr(args, x + y) is None: + setattr(args, x + y, getattr(args, x)) + + gain_or_bias_pretrained_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + rest_pretrained_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + gain_or_bias_new_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) + and (not is_pretrained_params(n)) + ] + rest_new_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) + and (not is_pretrained_params(n)) + ] + + pretrained_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0}, + { + "params": rest_pretrained_params, + "weight_decay": args.wd_pretrained, + }, + ], + lr=args.lr_pretrained, + betas=(args.beta1_pretrained, args.beta2_pretrained), + eps=args.eps_pretrained, + momentum=args.momentum_pretrained, + optimizer_name=args.optimizer, + ) + pretrained_params_scheduler = cosine_lr( + pretrained_params_optimizer, + args.lr_pretrained, + args.warmup, + total_steps, + ) + + new_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_new_params, "weight_decay": 0.0}, + {"params": rest_new_params, "weight_decay": args.wd_new}, + ], + lr=args.lr_new, + betas=(args.beta1_new, args.beta2_new), + eps=args.eps_new, + momentum=args.momentum_new, + optimizer_name=args.optimizer, + ) + new_params_scheduler = cosine_lr( + new_params_optimizer, args.lr_new, args.warmup, total_steps + ) + + optimizer["text"] = pretrained_params_optimizer + optimizer["audio"] = new_params_optimizer + scheduler["text"] = pretrained_params_scheduler + scheduler["audio"] = new_params_scheduler + + if args.horovod: + pretrained_params_optimizer = hvd.DistributedOptimizer( + pretrained_params_optimizer, + named_parameters=model.named_parameters(), + ) + new_params_optimizer = hvd.DistributedOptimizer( + new_params_optimizer, named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state( + pretrained_params_optimizer, root_rank=0 + ) + hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0) + else: + + optimizer["clap"] = get_optimizer( + [ + {"params": gain_or_bias_params, "weight_decay": 0.0}, + {"params": rest_params, "weight_decay": args.wd}, + ], + lr=args.lr, + betas=(args.beta1, args.beta2), + eps=args.eps, + momentum=args.momentum, + optimizer_name=args.optimizer, + ) + scheduler["clap"] = cosine_lr( + optimizer["clap"], args.lr, args.warmup, total_steps + ) + + if args.horovod: + optimizer["clap"] = hvd.DistributedOptimizer( + optimizer["clap"], named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(optimizer["clap"], root_rank=0) + + # linear probe optimizer + else: + lp_params = [ + p for n, p in named_parameters if (not in_clap(n, p)) and p.requires_grad + ] + lp_optim = get_optimizer( + lp_params, + lr=args.lp_lr, + betas=(args.beta1, args.beta2), + eps=args.eps, + momentum=0.9, + optimizer_name=args.optimizer, + ) + optimizer["lp"] = lp_optim + + return optimizer, scheduler, text_freeze_parameters + + +def main(): + args = parse_args() + + time.sleep(args.sleep) + + # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? + args.amodel = args.amodel.replace("/", "-") + # download sizes.json file + + # (yusong): the below two lines are for debug + # print("setting up faulthandler") + # faulthandler.register(10) + + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + args.class_index_dict = load_class_label(args.class_label_path) + + # get the name of the experiments + if args.name is None: + args.name = "-".join( + [ + datetime.now().strftime("%Y_%m_%d-%H_%M_%S"), + f"linear_probe" f"model_{args.amodel}", + f"lr_{args.lr}", + f"b_{args.batch_size}", + f"j_{args.workers}", + f"p_{args.precision}", + ] + ) + + # discover initial world args early so we can log properly + args.distributed = False + args.local_rank, args.rank, args.world_size = world_info_from_env() + + if args.remotedata and is_master(args): + for dataset_name in args.datasetnames: + for split in dataset_split[dataset_name]: + if not os.path.exists(f"./json_files/{dataset_name}/{split}"): + os.makedirs(f"./json_files/{dataset_name}/{split}") + os.system( + f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" + ) + + args.log_path = None + if is_master(args, local=args.log_local): + log_base_path = os.path.join(args.logs, args.name) + os.makedirs(log_base_path, exist_ok=True) + log_filename = f"out-{args.rank}" if args.log_local else "out.log" + args.log_path = os.path.join(log_base_path, log_filename) + + # avoid log dir in same name: + postfix = 0 + while os.path.exists(args.log_path): + postfix += 1 + log_base_path_new = log_base_path + "-" + str(postfix) + os.makedirs(log_base_path_new, exist_ok=True) + log_filename = f"out-{args.rank}" if args.log_local else "out.log" + args.log_path = os.path.join(log_base_path_new, log_filename) + # print( + # "Error. Experiment already exists. Use --name {} to specify a new experiment." + # ) + # return -1 + + # Set logger + args.log_level = logging.DEBUG if args.debug else logging.INFO + setup_logging(args.log_path, args.log_level) + + # fully initialize distributed device environment + device = init_distributed_device(args) + + args.wandb = "wandb" in args.report_to or "all" in args.report_to + args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to + if is_master(args): + args.tensorboard_path = ( + os.path.join(args.logs, args.name, "tensorboard") + if args.tensorboard + else "" + ) + args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") + for dirname in [args.tensorboard_path, args.checkpoint_path]: + if dirname: + os.makedirs(dirname, exist_ok=True) + else: + args.tensorboard_path = "" + args.checkpoint_path = "" + + if args.copy_codebase: + copy_codebase(args) + + assert args.precision in ["amp", "fp16", "fp32"] + if args.precision == "fp16": + logging.warning( + "It is recommended to use AMP mixed-precision instead of FP16. " + "FP16 support needs further verification and tuning, especially for train." + ) + + if args.horovod: + logging.info( + f"Running in horovod mode with multiple processes / nodes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + elif args.distributed: + logging.info( + f"Running in distributed mode with multiple processes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + else: + logging.info(f"Running with a single process. Device {args.device}.") + + logging.info(f"openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}") + + # Create CLAP model + clap_model, clap_model_cfg = create_model( + args.amodel, + args.tmodel, + args.pretrained, + precision=args.precision, + device=device, + jit=args.torchscript, + force_quick_gelu=args.force_quick_gelu, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=False, + pretrained_audio=args.pretrained_audio, + pretrained_text=args.pretrained_text, + enable_fusion=args.enable_fusion, + fusion_type=args.fusion_type, + ) + + args.lp_out_ch = len(list(args.class_index_dict.keys())) + # Linear Probe + logging.info(f"linear probe using mlp: {args.lp_mlp}") + logging.info(f"linear probe using freeze: {args.lp_freeze}") + logging.info(f"linear probe act layer: {args.lp_act}") + logging.info(f"linear probe out ch: {args.lp_out_ch}") + logging.info(f"linear probe learning rate (if applicable): {args.lp_lr}") + logging.info(f"linear probe loss func: {args.lp_loss}") + logging.info(f"linear probe lp_metrics: {args.lp_metrics}") + + model = LinearProbe( + clap_model, + mlp=args.lp_mlp, + freeze=args.lp_freeze, + in_ch=512, + out_ch=args.lp_out_ch, + act=args.lp_act, + ) # in_ch is fixed (i.e., 512) + model = model.to(device) + + if args.horovod: + with torch.no_grad(): + for param in model.parameters(): + param.set_(param.contiguous()) + + if args.trace: + model = trace_model(model, batch_size=args.batch_size, device=device) + + if is_master(args): + logging.info("Linear Probe CLAP Model:") + logging.info(f"{str(clap_model)}") + logging.info("Params:") + params_file = os.path.join(args.logs, args.name, "params.txt") + with open(params_file, "w") as f: + for name in sorted(vars(args)): + val = getattr(args, name) + logging.info(f" {name}: {val}") + f.write(f"{name}: {val}\n") + + if args.distributed and not args.horovod: + if args.use_bn_sync: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + ddp_args = {} + if args.ddp_static_graph: + # this doesn't exist in older PyTorch, arg only added if enabled + ddp_args["static_graph"] = True + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[device], find_unused_parameters=True, **ddp_args + ) + + data = get_data(args, clap_model_cfg) + assert len(data), "At least one train or eval dataset must be specified." + if args.trace: + assert "train" not in data, "Cannot train with traced model" + + optimizer, scheduler, text_freeze_parameters = config_lp_optimizer( + model, data, args + ) + + scaler = GradScaler() if args.precision == "amp" else None + + # optionally resume from a checkpoint + start_epoch = 0 + if args.resume is not None: + if os.path.isfile(args.resume): + checkpoint = torch.load(args.resume, map_location=device) + if "epoch" in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if not args.distributed and next(iter(sd.items()))[0].startswith( + "module" + ): + sd = {k[len("module.") :]: v for k, v in sd.items()} + model.load_state_dict(sd) + if args.split_opt: + if optimizer is not None: + for k, o_ in optimizer.items(): + o_.load_state_dict(checkpoint[k + "_" + "optimizer"]) + if optimizer is not None: + optimizer.load_state_dict(checkpoint["optimizer"]) + if scaler is not None and "scaler" in checkpoint: + scaler.load_state_dict(checkpoint["scaler"]) + logging.info( + f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})" + ) + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + logging.info( + f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})" + ) + if args.freeze_text: + print("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + else: + logging.info("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + cudnn.deterministic = False + + # determine if this worker should save logs and checkpoints. only do so if it is rank == 0 + args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args) + writer = None + if args.save_logs and args.tensorboard: + assert tensorboard is not None, "Please install tensorboard." + writer = tensorboard.SummaryWriter(args.tensorboard_path) + + if args.wandb and is_master(args): + assert wandb is not None, "Please install wandb." + logging.debug("Starting wandb.") + args.train_sz = data["train"].dataloader.num_samples + if args.val_data is not None: + args.val_sz = data["val"].dataloader.num_samples + # you will have to configure this for your project! + wandb.init( + project="clap", + notes=args.wandb_notes, + name=args.wandb_notes, + tags=[], + config=vars(args), + ) + if args.debug: + wandb.watch(model, log="all") + wandb.save(params_file) + logging.debug("Finished loading wandb.") + + if "train" not in data: + evaluate(model, data, start_epoch, args, writer) + return + elif start_epoch == 0 and "val" in data and not args.no_eval: + evaluate(model, data, 0, args, writer) + if args.save_top_performance: + current_top_k_ckpt_metrics = { + i: 0 for i in range(args.save_top_performance) + } # initialize the top-k metric for ckpts to 0 + + for epoch in range(start_epoch, args.epochs): + # freeze the text param after (include) args.freeze_text_after, this is -1 by default + if epoch == args.freeze_text_after: + print("Text pretrained parameters are freezed since this epoch.") + for k in text_freeze_parameters: + k.requires_grad = False + if is_master(args): + logging.info(f"Start epoch {epoch}") + + train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer) + completed_epoch = epoch + 1 + + if ( + any(v in data for v in ("val", "imagenet-val", "imagenet-v2")) + and not args.no_eval + ): + metrics = evaluate(model, data, completed_epoch, args, writer) + if args.save_top_performance: + top_k_dataset = args.top_k_checkpoint_select_dataset + top_k_metric = args.top_k_checkpoint_select_metric + filtered_metrics = [ + v + for k, v in metrics.items() + if top_k_metric in k and top_k_dataset in k + ] # check all R@10 metrics (all dataset) and use it to update the ckpt + # Saving checkpoints. + if args.save_logs: + opt_dict = { + k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items() + } + checkpoint_dict = { + "epoch": completed_epoch, + "name": args.name, + "state_dict": model.state_dict(), + } + checkpoint_dict.update(opt_dict) + if scaler is not None: + checkpoint_dict["scaler"] = scaler.state_dict() + + if completed_epoch == args.epochs or ( + args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 + ): + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), + ) + if args.save_most_recent: + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_latest.pt"), + ) + if args.save_top_performance and not args.no_eval: + update_top_k_performance( + filtered_metrics, + current_top_k_ckpt_metrics, + args, + checkpoint_dict, + bignumbetter=True, + ) + + if args.wandb and is_master(args): + wandb.finish() + + +def copy_codebase(args): + from shutil import copytree, ignore_patterns + + new_code_path = os.path.join(args.logs, args.name, "code") + if os.path.exists(new_code_path): + print( + f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." + ) + return -1 + print(f"Copying codebase to {new_code_path}") + current_code_path = os.path.realpath(__file__) + for _ in range(3): + current_code_path = os.path.dirname(current_code_path) + copytree( + current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb") + ) + print("Done copying code.") + return 1 + + +if __name__ == "__main__": + main() diff --git a/audioldm/clap/training/lp_train.py b/audioldm/clap/training/lp_train.py new file mode 100755 index 0000000000000000000000000000000000000000..24a19bacd0a4b789415cfccbce1f8bc99bc493ed --- /dev/null +++ b/audioldm/clap/training/lp_train.py @@ -0,0 +1,301 @@ +import json +import logging +import math +import os +import time +from contextlib import suppress + +import numpy as np +import torch +import torch.nn.functional as F + +try: + import wandb +except ImportError: + wandb = None + +from open_clip import LPLoss, LPMetrics, lp_gather_features +from open_clip.utils import do_mixup, get_mix_lambda +from .distributed import is_master +from .zero_shot import zero_shot_eval + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def unwrap_model(model): + if hasattr(model, "module"): + return model.module + else: + return model + + +def train_one_epoch( + model, + data, + epoch, + optimizer, + scaler, + scheduler, + args, + tb_writer=None, + extra_suffix="", +): + device = torch.device(args.device) + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + model.train() + loss = LPLoss(args.lp_loss) + + dataloader, sampler = data["train"].dataloader, data["train"].sampler + if args.distributed and sampler is not None: + sampler.set_epoch(epoch) + num_batches_per_epoch = dataloader.num_batches + sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) + + # for toy dataset + if args.dataset_type == "toy": + dataloader.dataset.generate_queue() + + loss_m = AverageMeter() + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for i, batch in enumerate(dataloader): + step = num_batches_per_epoch * epoch + i + + if isinstance(scheduler, dict): + for s in scheduler.values(): + s(step) + else: + scheduler(step) + + audio = batch # contains mel_spec, wavform, and longer list + class_label = batch["class_label"] + # audio = audio.to(device=device, non_blocking=True) + class_label = class_label.to(device=device, non_blocking=True) + + if args.mixup: + # https://github.com/RetroCirce/HTS-Audio-Transformer/blob/main/utils.py#L146 + mix_lambda = torch.from_numpy( + get_mix_lambda(0.5, len(audio["waveform"])) + ).to(device) + class_label = do_mixup(class_label, mix_lambda) + else: + mix_lambda = None + + data_time_m.update(time.time() - end) + if isinstance(optimizer, dict): + for o_ in optimizer.values(): + o_.zero_grad() + else: + optimizer.zero_grad() + + with autocast(): + pred = model(audio, mix_lambda=mix_lambda, device=device) + total_loss = loss(pred, class_label) + + if isinstance(optimizer, dict): + if scaler is not None: + scaler.scale(total_loss).backward() + for o_ in optimizer.values(): + if args.horovod: + o_.synchronize() + scaler.unscale_(o_) + with o_.skip_synchronize(): + scaler.step(o_) + else: + scaler.step(o_) + scaler.update() + else: + total_loss.backward() + for o_ in optimizer.values(): + o_.step() + else: + if scaler is not None: + scaler.scale(total_loss).backward() + if args.horovod: + optimizer.synchronize() + scaler.unscale_(optimizer) + with optimizer.skip_synchronize(): + scaler.step(optimizer) + else: + scaler.step(optimizer) + scaler.update() + else: + total_loss.backward() + optimizer.step() + + # Note: we clamp to 4.6052 = ln(100), as in the original paper. + with torch.no_grad(): + unwrap_model(model).clap_model.logit_scale_a.clamp_(0, math.log(100)) + unwrap_model(model).clap_model.logit_scale_t.clamp_(0, math.log(100)) + + batch_time_m.update(time.time() - end) + end = time.time() + batch_count = i + 1 + + if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch): + if isinstance(audio, dict): + batch_size = len(audio["waveform"]) + else: + batch_size = len(audio) + num_samples = batch_count * batch_size * args.world_size + samples_per_epoch = dataloader.num_samples + percent_complete = 100.0 * batch_count / num_batches_per_epoch + + # NOTE loss is coarsely sampled, just master node and per log update + loss_m.update(total_loss.item(), batch_size) + if isinstance(optimizer, dict): + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]}" + ) + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], + } + else: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + ) + + # Save train loss / etc. Using non avg meter values as loggers have their own smoothing + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "lr": optimizer.param_groups[0]["lr"], + } + for name, val in log_data.items(): + name = f"train{extra_suffix}/{name}" + if tb_writer is not None: + tb_writer.add_scalar(name, val, step) + if args.wandb: + assert wandb is not None, "Please install wandb." + wandb.log({name: val, "step": step}) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # end for + + +def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""): + metrics = {} + if not args.parallel_eval: + if not is_master(args): + return metrics + device = torch.device(args.device) + model.eval() + + # CHANGE + # zero_shot_metrics = zero_shot_eval(model, data, epoch, args) + # metrics.update(zero_shot_metrics) + if is_master(args): + print("Evaluating...") + metric_names = args.lp_metrics.split(",") + eval_tool = LPMetrics(metric_names=metric_names) + + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + if "val" in data and ( + args.val_frequency + and ((epoch % args.val_frequency) == 0 or epoch == args.epochs) + ): + if args.parallel_eval: + dataloader, sampler = data["val"].dataloader, data["val"].sampler + if args.distributed and sampler is not None: + sampler.set_epoch(epoch) + samples_per_val = dataloader.num_samples + else: + dataloader = data["val"].dataloader + num_samples = 0 + samples_per_val = dataloader.num_samples + + eval_info = {"pred": [], "target": []} + with torch.no_grad(): + for i, batch in enumerate(dataloader): + audio = batch # contains mel_spec, wavform, and longer list + class_label = batch["class_label"] + + # audio = audio.to(device=device, non_blocking=True) + class_label = class_label.to(device=device, non_blocking=True) + + with autocast(): + pred = model(audio, device=device) + if args.parallel_eval: + pred, class_label = lp_gather_features( + pred, class_label, args.world_size, args.horovod + ) + eval_info["pred"].append(pred) + eval_info["target"].append(class_label) + + num_samples += class_label.shape[0] + + if (i % 100) == 0: # and i != 0: + logging.info( + f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]" + ) + + if is_master(args): + eval_info["pred"] = torch.cat(eval_info["pred"], 0).cpu() + eval_info["target"] = torch.cat(eval_info["target"], 0).cpu() + metric_dict = eval_tool.evaluate_mertics( + eval_info["pred"], eval_info["target"] + ) + metrics.update(metric_dict) + if "epoch" not in metrics.keys(): + metrics.update({"epoch": epoch}) + + if is_master(args): + if not metrics: + return metrics + + logging.info( + f"Eval Epoch: {epoch} " + + "\n".join( + ["\t".join([f"{m}: {round(metrics[m], 4):.4f}"]) for m in metrics] + ) + ) + if args.save_logs: + for name, val in metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val{extra_suffix}/{name}", val, epoch) + + with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: + f.write(json.dumps(metrics)) + f.write("\n") + + if args.wandb: + assert wandb is not None, "Please install wandb." + for name, val in metrics.items(): + wandb.log({f"val{extra_suffix}/{name}": val, "epoch": epoch}) + + return metrics + else: + return metrics diff --git a/audioldm/clap/training/main.py b/audioldm/clap/training/main.py new file mode 100755 index 0000000000000000000000000000000000000000..3b563a5d001be7adfbe779dee7ad8ac49aadc50d --- /dev/null +++ b/audioldm/clap/training/main.py @@ -0,0 +1,596 @@ +from inspect import getargs +import logging +import os +import random +from datetime import datetime +import bisect +import copy +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from torch import optim +from torch.cuda.amp import GradScaler +import faulthandler +import pathlib + +try: + import wandb +except ImportError: + wandb = None + +try: + import torch.utils.tensorboard as tensorboard +except ImportError: + tensorboard = None + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +from open_clip import create_model_and_transforms, trace_model, create_model +from training.data import get_data +from training.distributed import is_master, init_distributed_device, world_info_from_env +from training.logger import setup_logging +from training.params import parse_args +from training.scheduler import cosine_lr +from training.train import train_one_epoch, evaluate +from open_clip.utils import dataset_split, get_optimizer + + +def maintain_ckpts(args, startidx, all_idx_len): + for i in reversed(range(startidx, all_idx_len)): + if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")): + os.rename( + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"), + ) + if os.path.exists( + os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt") + ): + os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")) + return + + +def update_top_k_performance( + new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True +): + """ + Record the top-k performance of the current epoch. + current_top_k_metrics is a dictionary of the form: {1: top_1_ckpt_measure, 2: top_2_ckpt_measure, ...} + """ + if isinstance(new_metrics_inputs, (list, tuple)): + new_metrics_inputs = np.mean(new_metrics_inputs) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + ) + elif isinstance(new_metrics_inputs, dict): + new_metrics_inputs = np.mean(list(new_metrics_inputs.values())) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + ) + elif isinstance(new_metrics_inputs, (float, int)): + update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()} + sorted_keys = sorted(current_top_k_ckpt_metrics.keys()) + sorted_values = sorted( + current_top_k_ckpt_metrics.values(), reverse=bignumbetter + ) + sorted_values_ = copy.deepcopy(sorted_values) + sorted_values.append(new_metrics_inputs) + sorted_values = sorted(sorted_values, reverse=bignumbetter) + sorted_values = sorted_values[:-1] + + if sorted_values == sorted_values_: + return current_top_k_ckpt_metrics, new_metrics_inputs + else: + for i in range(len(sorted_keys)): + if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]: + current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i] + update_flag[sorted_keys[i]] = True + for i in range(len(update_flag)): + if update_flag[i]: + maintain_ckpts(args, i, len(sorted_keys)) + torch.save( + ckpt, + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + ) + break + return current_top_k_ckpt_metrics, new_metrics_inputs + + +# def updateifNone(a, b): +# a = b if None else a +# return a + + +def is_pretrained_params(n): + return ( + n.startswith("transformer") + or n in ["positional_embedding", "text_projection"] + or n.startswith("token_embedding") + or n.startswith("ln_final") + or n.startswith("logit_scale_t") + ) + + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) + + +def main(): + args = parse_args() + # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? + args.amodel = args.amodel.replace("/", "-") + # download sizes.json file + + # (yusong): the below two lines are for debug + # print("setting up faulthandler") + # faulthandler.register(10) + + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + if args.tmodel == "bert" or args.tmodel == "roberta" or args.tmodel == "bart": + assert ( + args.pretrained == "" or args.pretrained is None + ), "bert/roberta/bart text encoder does not support pretrained models." + + # get the name of the experiments + if args.name is None: + args.name = "-".join( + [ + datetime.now().strftime("%Y_%m_%d-%H_%M_%S"), + f"model_{args.amodel}", + f"lr_{args.lr}", + f"b_{args.batch_size}", + f"j_{args.workers}", + f"p_{args.precision}", + ] + ) + + # discover initial world args early so we can log properly + args.distributed = False + args.local_rank, args.rank, args.world_size = world_info_from_env() + + if args.remotedata and is_master(args): + for dataset_name in args.datasetnames: + for split in dataset_split[dataset_name]: + if not os.path.exists(f"./json_files/{dataset_name}/{split}"): + os.makedirs(f"./json_files/{dataset_name}/{split}") + os.system( + f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" + ) + + args.log_path = None + if is_master(args, local=args.log_local): + log_base_path = os.path.join(args.logs, args.name) + os.makedirs(log_base_path, exist_ok=True) + log_filename = f"out-{args.rank}" if args.log_local else "out.log" + args.log_path = os.path.join(log_base_path, log_filename) + if os.path.exists(args.log_path): + print( + "Error. Experiment already exists. Use --name {} to specify a new experiment." + ) + return -1 + + # Set logger + args.log_level = logging.DEBUG if args.debug else logging.INFO + setup_logging(args.log_path, args.log_level) + + # fully initialize distributed device environment + device = init_distributed_device(args) + + args.wandb = "wandb" in args.report_to or "all" in args.report_to + args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to + if is_master(args): + args.tensorboard_path = ( + os.path.join(args.logs, args.name, "tensorboard") + if args.tensorboard + else "" + ) + args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") + for dirname in [args.tensorboard_path, args.checkpoint_path]: + if dirname: + os.makedirs(dirname, exist_ok=True) + else: + args.tensorboard_path = "" + args.checkpoint_path = "" + + if args.copy_codebase: + copy_codebase(args) + + assert args.precision in ["amp", "fp16", "fp32"] + if args.precision == "fp16": + logging.warning( + "It is recommended to use AMP mixed-precision instead of FP16. " + "FP16 support needs further verification and tuning, especially for train." + ) + + if args.horovod: + logging.info( + f"Running in horovod mode with multiple processes / nodes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + elif args.distributed: + logging.info( + f"Running in distributed mode with multiple processes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + else: + logging.info(f"Running with a single process. Device {args.device}.") + + logging.info(f"openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}") + + model, model_cfg = create_model( + args.amodel, + args.tmodel, + args.pretrained, + precision=args.precision, + device=device, + jit=args.torchscript, + force_quick_gelu=args.force_quick_gelu, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=True, + pretrained_audio=args.pretrained_audio, + pretrained_text=args.pretrained_text, + enable_fusion=args.enable_fusion, + fusion_type=args.fusion_type, + ) + + if args.horovod: + with torch.no_grad(): + for param in model.parameters(): + param.set_(param.contiguous()) + + if args.trace: + model = trace_model(model, batch_size=args.batch_size, device=device) + + if is_master(args): + logging.info("Model:") + logging.info(f"{str(model)}") + logging.info("Params:") + params_file = os.path.join(args.logs, args.name, "params.txt") + with open(params_file, "w") as f: + for name in sorted(vars(args)): + val = getattr(args, name) + logging.info(f" {name}: {val}") + f.write(f"{name}: {val}\n") + + if args.distributed and not args.horovod: + if args.use_bn_sync: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + ddp_args = {} + if args.ddp_static_graph: + # this doesn't exist in older PyTorch, arg only added if enabled + ddp_args["static_graph"] = True + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[device], find_unused_parameters=True, **ddp_args + ) + + data = get_data(args, model_cfg) + assert len(data), "At least one train or eval dataset must be specified." + if args.trace: + assert "train" not in data, "Cannot train with traced model" + + exclude = ( + lambda n, p: p.ndim < 2 + or "bn" in n + or "ln" in n + or "bias" in n + or "logit_scale" in n + ) + include = lambda n, p: not exclude(n, p) + + named_parameters = list(model.named_parameters()) + + # freeze text encoder + text_freeze_parameters = [p for n, p in named_parameters if "text_branch" in n] + + if args.freeze_text: + print("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + + gain_or_bias_params = [ + p for n, p in named_parameters if exclude(n, p) and p.requires_grad + ] + rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] + + # set wd-related params to 0 if use adam optimizer + if args.optimizer == "adam": + args.wd = 0 + args.wd_pretrained = 0 + args.wd_new = 0 + + if args.train_data is None: + optimizer = None + scheduler = None + else: + total_steps = data["train"].dataloader.num_batches * args.epochs + + if args.split_opt: + for x in ["lr", "beta1", "beta2", "eps", "wd"]: + for y in ["_new", "_pretrained"]: + if getattr(args, x + y) is None: + setattr(args, x + y, getattr(args, x)) + + gain_or_bias_pretrained_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + rest_pretrained_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + gain_or_bias_new_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) and (not is_pretrained_params(n)) + ] + rest_new_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) and (not is_pretrained_params(n)) + ] + pretrained_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0}, + { + "params": rest_pretrained_params, + "weight_decay": args.wd_pretrained, + }, + ], + lr=args.lr_pretrained, + betas=(args.beta1_pretrained, args.beta2_pretrained), + eps=args.eps_pretrained, + momentum=args.momentum_pretrained, + optimizer_name=args.optimizer, + ) + pretrained_params_scheduler = cosine_lr( + pretrained_params_optimizer, + args.lr_pretrained, + args.warmup, + total_steps, + ) + new_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_new_params, "weight_decay": 0.0}, + {"params": rest_new_params, "weight_decay": args.wd_new}, + ], + lr=args.lr_new, + betas=(args.beta1_new, args.beta2_new), + eps=args.eps_new, + momentum=args.momentum_new, + optimizer_name=args.optimizer, + ) + + new_params_scheduler = cosine_lr( + new_params_optimizer, args.lr_new, args.warmup, total_steps + ) + + optimizer = { + "pretrained": pretrained_params_optimizer, + "new": new_params_optimizer, + } + scheduler = { + "pretrained": pretrained_params_scheduler, + "new": new_params_scheduler, + } + + if args.horovod: + pretrained_params_optimizer = hvd.DistributedOptimizer( + pretrained_params_optimizer, + named_parameters=model.named_parameters(), + ) + new_params_optimizer = hvd.DistributedOptimizer( + new_params_optimizer, named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(pretrained_params_optimizer, root_rank=0) + hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0) + else: + optimizer = get_optimizer( + [ + {"params": gain_or_bias_params, "weight_decay": 0.0}, + {"params": rest_params, "weight_decay": args.wd}, + ], + lr=args.lr, + betas=(args.beta1, args.beta2), + eps=args.eps, + momentum=args.momentum, + optimizer_name=args.optimizer, + ) + + scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps) + + if args.horovod: + optimizer = hvd.DistributedOptimizer( + optimizer, named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(optimizer, root_rank=0) + + scaler = GradScaler() if args.precision == "amp" else None + + # optionally resume from a checkpoint + start_epoch = 0 + if args.resume is not None: + if os.path.isfile(args.resume): + checkpoint = torch.load(args.resume, map_location=device) + if "epoch" in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if not args.distributed and next(iter(sd.items()))[0].startswith( + "module" + ): + sd = {k[len("module.") :]: v for k, v in sd.items()} + model.load_state_dict(sd) + if args.split_opt: + if optimizer is not None: + for k, o_ in optimizer.items(): + o_.load_state_dict(checkpoint[k + "_" + "optimizer"]) + if optimizer is not None: + optimizer.load_state_dict(checkpoint["optimizer"]) + if scaler is not None and "scaler" in checkpoint: + scaler.load_state_dict(checkpoint["scaler"]) + logging.info( + f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})" + ) + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + logging.info( + f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})" + ) + if args.freeze_text: + print("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + else: + logging.info("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + cudnn.deterministic = False + + # determine if this worker should save logs and checkpoints. only do so if it is rank == 0 + args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args) + writer = None + if args.save_logs and args.tensorboard: + assert tensorboard is not None, "Please install tensorboard." + writer = tensorboard.SummaryWriter(args.tensorboard_path) + + if args.wandb and is_master(args): + assert wandb is not None, "Please install wandb." + logging.debug("Starting wandb.") + args.train_sz = data["train"].dataloader.num_samples + if args.val_data is not None: + args.val_sz = data["val"].dataloader.num_samples + # you will have to configure this for your project! + wandb.init( + project="clap", + notes=args.wandb_notes, + name=args.wandb_notes, + tags=[], + config=vars(args), + ) + if args.debug: + wandb.watch(model, log="all") + wandb.save(params_file) + logging.debug("Finished loading wandb.") + + if "train" not in data: + evaluate(model, data, start_epoch, args, writer) + return + elif start_epoch == 0 and "val" in data and not args.no_eval: + evaluate(model, data, 0, args, writer) + # print(f'rank {args.rank}, Start First Evaluation')# (yusong): for debug + if args.save_top_performance: + current_top_k_ckpt_metrics = { + i: 0 for i in range(args.save_top_performance) + } # initialize the top-k metric for ckpts to 0 + + # print(f'rank {args.rank}, Start Training') # (yusong): for debug + for epoch in range(start_epoch, args.epochs): + # freeze the text param after (include) args.freeze_text_after, this is -1 by default + if epoch == args.freeze_text_after: + print("Text pretrained parameters are freezed since this epoch.") + for k in text_freeze_parameters: + k.requires_grad = False + if is_master(args): + logging.info(f"Start epoch {epoch}") + + train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer) + completed_epoch = epoch + 1 + + if ( + any(v in data for v in ("val", "imagenet-val", "imagenet-v2")) + and not args.no_eval + ): + metrics = evaluate(model, data, completed_epoch, args, writer) + if args.save_top_performance: + top_k_dataset = args.top_k_checkpoint_select_dataset + top_k_metric = args.top_k_checkpoint_select_metric + filtered_metrics = [ + v + for k, v in metrics.items() + if top_k_metric in k and top_k_dataset in k + ] # check all R@10 metrics (all dataset) and use it to update the ckpt + # Saving checkpoints. + if args.save_logs: + if args.split_opt: + opt_dict = { + k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items() + } + else: + opt_dict = {"optimizer": optimizer.state_dict()} + checkpoint_dict = { + "epoch": completed_epoch, + "name": args.name, + "state_dict": model.state_dict(), + } + checkpoint_dict.update(opt_dict) + if scaler is not None: + checkpoint_dict["scaler"] = scaler.state_dict() + + if completed_epoch == args.epochs or ( + args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 + ): + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), + ) + if args.save_most_recent: + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_latest.pt"), + ) + if args.save_top_performance and not args.no_eval: + update_top_k_performance( + filtered_metrics, + current_top_k_ckpt_metrics, + args, + checkpoint_dict, + bignumbetter=True, + ) + + if args.wandb and is_master(args): + wandb.finish() + + +def copy_codebase(args): + from shutil import copytree, ignore_patterns + + new_code_path = os.path.join(args.logs, args.name, "code") + if os.path.exists(new_code_path): + print( + f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." + ) + return -1 + print(f"Copying codebase to {new_code_path}") + current_code_path = os.path.realpath(__file__) + for _ in range(3): + current_code_path = os.path.dirname(current_code_path) + copytree( + current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb") + ) + print("Done copying code.") + return 1 + + +if __name__ == "__main__": + main() diff --git a/audioldm/clap/training/params.py b/audioldm/clap/training/params.py new file mode 100755 index 0000000000000000000000000000000000000000..b1933e3a78ff583733846ea285d56eb0a0b892a5 --- /dev/null +++ b/audioldm/clap/training/params.py @@ -0,0 +1,569 @@ +import argparse +import os + +CACHE_DIR = os.getenv( + "AUDIOLDM_CACHE_DIR", + "~/.cache") + + + +def get_default_params(model_name): + # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) + model_name = model_name.lower() + if "vit" in model_name: + return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6} + else: + return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8} + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--train-data", + type=str, + default=None, + help="Path to h5 filewith training data", + ) + parser.add_argument( + "--val-data", + type=str, + default=None, + help="Path to h5 file with validation data", + ) + parser.add_argument( + "--freeze-text", + default=False, + action="store_true", + help="if you need to freeze the text encoder, make this True", + ) + parser.add_argument( + "--freeze-text-after", + type=int, + default=-1, + help="if you need to freeze the text encoder after (include) epoch x, set this param to x. Set -1 to disable it", + ) + parser.add_argument( + "--train-ipc", + type=str, + default=None, + help="Path to npy file of the number of instance per class in training data", + ) + parser.add_argument( + "--val-ipc", + type=str, + default=None, + help="Path to npy file of the number of instance per class in validation data", + ) + parser.add_argument( + "--train-num-samples", + type=int, + default=None, + help="Number of samples in dataset. Required for webdataset if not available in info file.", + ) + parser.add_argument( + "--val-num-samples", + type=int, + default=None, + help="Number of samples in dataset. Useful for webdataset if not available in info file.", + ) + parser.add_argument( + "--dataset-type", + choices=["webdataset", "csv", "auto", "toy"], + default="auto", + help="Which type of dataset to process.", + ) + parser.add_argument( + "--csv-separator", + type=str, + default="\t", + help="For csv-like datasets, which separator to use.", + ) + parser.add_argument( + "--csv-img-key", + type=str, + default="filepath", + help="For csv-like datasets, the name of the key for the image paths.", + ) + parser.add_argument( + "--csv-caption-key", + type=str, + default="title", + help="For csv-like datasets, the name of the key for the captions.", + ) + parser.add_argument( + "--imagenet-val", + type=str, + default=None, + help="Path to imagenet val set for conducting zero shot evaluation.", + ) + parser.add_argument( + "--imagenet-v2", + type=str, + default=None, + help="Path to imagenet v2 for conducting zero shot evaluation.", + ) + parser.add_argument( + "--datasetnames", + nargs="+", + default=None, + help="If loading webdataset, spedify the dataset names to load. Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects", + ) + parser.add_argument( + "--full-train-dataset", + nargs="+", + default=None, + help="Which dataset will be trained with all the subsets. (train+test)", + ) + parser.add_argument( + "--exclude-eval-dataset", + nargs="+", + default=None, + help="Which dataset will be excluded with evaluation", + ) + parser.add_argument( + "--datasetinfos", + nargs="+", + default=None, + help="If loading webdataset, spedify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval", + ) + parser.add_argument( + "--dataset-proportion", + type=float, + default=1.0, + help="How much proportion of dataset we want to train.", + ) + parser.add_argument( + "--remotedata", + default=False, + action="store_true", + help="if the dataset is remote, set this flag", + ) + parser.add_argument( + "--class-label-path", + type=str, + default=None, + help="The path of the class label pickle or csv.", + ) + parser.add_argument( + "--datasetpath", + type=str, + default="/mnt/audio_clip/webdataset_tar", + help="The path to the dataset", + ) + parser.add_argument( + "--logs", + type=str, + default="./logs/", + help="Where to store tensorboard logs. Use None to avoid storing logs.", + ) + parser.add_argument( + "--log-local", + action="store_true", + default=False, + help="log files on local master, otherwise global master only.", + ) + parser.add_argument( + "--name", + type=str, + default=None, + help="Optional identifier for the experiment when storing logs. Otherwise use current time.", + ) + parser.add_argument( + "--workers", type=int, default=1, help="Number of workers per GPU." + ) + parser.add_argument( + "--batch-size", type=int, default=64, help="Batch size per GPU." + ) + parser.add_argument( + "--epochs", type=int, default=32, help="Number of epochs to train for." + ) + parser.add_argument("--lr", type=float, default=None, help="Learning rate.") + parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.") + parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.") + parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.") + parser.add_argument("--momentum", type=float, default=None, help="SGD epsilon.") + parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") + + parser.add_argument( + "--split-opt", + action="store_true", + default=False, + help="Use this flag to skip the learning rate decay.", + ) + parser.add_argument( + "--lr-pretrained", type=float, default=None, help="Learning rate for text." + ) + parser.add_argument( + "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text." + ) + parser.add_argument( + "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text." + ) + parser.add_argument( + "--eps-pretrained", type=float, default=None, help="Adam epsilon for text." + ) + parser.add_argument( + "--wd-pretrained", type=float, default=0.2, help="Weight decay for text." + ) + parser.add_argument( + "--momentum-pretrained", type=float, default=0.9, help="Momentum for text." + ) + parser.add_argument( + "--lr-new", type=float, default=None, help="Learning rate for audio." + ) + parser.add_argument( + "--beta1-new", type=float, default=None, help="Adam beta 1 for audio." + ) + parser.add_argument( + "--beta2-new", type=float, default=None, help="Adam beta 2 for audio." + ) + parser.add_argument( + "--eps-new", type=float, default=None, help="Adam epsilon for audio." + ) + parser.add_argument( + "--wd-new", type=float, default=0.2, help="Weight decay for audio." + ) + parser.add_argument( + "--momentum-new", type=float, default=0.9, help="Momentum for audio." + ) + parser.add_argument( + "--warmup", type=int, default=10000, help="Number of steps to warmup for." + ) + parser.add_argument( + "--use-bn-sync", + default=False, + action="store_true", + help="Whether to use batch norm sync.", + ) + parser.add_argument( + "--skip-scheduler", + action="store_true", + default=False, + help="Use this flag to skip the learning rate decay.", + ) + parser.add_argument( + "--save-frequency", type=int, default=1, help="How often to save checkpoints." + ) + parser.add_argument( + "--save-top-performance", + type=int, + default=0, + help="Save the top x performance weights if the value >0", + ) + parser.add_argument( + "--save-most-recent", + action="store_true", + default=False, + help="Always save the most recent model trained to epoch_latest.pt.", + ) + parser.add_argument( + "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot." + ) + parser.add_argument( + "--val-frequency", + type=int, + default=1, + help="How often to run evaluation with val data.", + ) + parser.add_argument( + "--resume", + default=None, + type=str, + help="path to latest checkpoint (default: none)", + ) + parser.add_argument( + "--precision", + choices=["amp", "fp16", "fp32"], + default="amp", + help="Floating point precision.", + ) + parser.add_argument( + "--amodel", + type=str, + default="RN50", + help="Name of the audio backbone to use.", + ) + parser.add_argument( + "--tmodel", + type=str, + default="transformer", + help="Name of the text backbone to use. Can be [transformer, bert, roberta, bart]", + ) + parser.add_argument( + "--pretrained-audio", + default="", + type=str, + help="Use a pretrained audio model weights for the audio encoder of CLAP", + ) + parser.add_argument( + "--pretrained-text", + default="", + type=str, + help="Use a pretrained text model weights for the text encoder of CLAP", + ) + parser.add_argument( + "--pretrained", + default="", + type=str, + help="Use a pretrained CLIP model weights with the specified tag or file path.", + ) + parser.add_argument( + "--pretrained-image", + default=False, + action="store_true", + help="Load imagenet pretrained weights for image tower backbone if available.", + ) + parser.add_argument( + "--lock-image", + default=False, + action="store_true", + help="Lock full image tower by disabling gradients.", + ) + parser.add_argument( + "--lock-image-unlocked-groups", + type=int, + default=0, + help="Leave last n image tower layer groups unlocked.", + ) + parser.add_argument( + "--lock-image-freeze-bn-stats", + default=False, + action="store_true", + help="Freeze BatchNorm running stats in image tower for any locked layers.", + ) + parser.add_argument( + "--local-loss", + default=False, + action="store_true", + help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)", + ) + parser.add_argument( + "--gather-with-grad", + default=False, + action="store_true", + help="enable full distributed gradient for feature gather", + ) + parser.add_argument( + "--force-quick-gelu", + default=False, + action="store_true", + help="Force use of QuickGELU activation for non-OpenAI transformer models.", + ) + parser.add_argument( + "--torchscript", + default=False, + action="store_true", + help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'", + ) + parser.add_argument( + "--trace", + default=False, + action="store_true", + help="torch.jit.trace the model for inference / eval only", + ) + # arguments for distributed training + parser.add_argument( + "--dist-url", + default="env://", + type=str, + help="url used to set up distributed training", + ) + parser.add_argument( + "--dist-backend", default="nccl", type=str, help="distributed backend" + ) + parser.add_argument( + "--report-to", + default="", + type=str, + help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']", + ) + parser.add_argument( + "--wandb-notes", default="", type=str, help="Notes if logging with wandb" + ) + parser.add_argument( + "--C", type=float, default=3.16, help="inverse regularizer for logistic reg." + ) + parser.add_argument( + "--debug", + default=False, + action="store_true", + help="If true, more information is logged.", + ) + parser.add_argument( + "--copy-codebase", + default=False, + action="store_true", + help="If true, we copy the entire base on the log diretory, and execute from there.", + ) + parser.add_argument( + "--horovod", + default=False, + action="store_true", + help="Use horovod for distributed training.", + ) + parser.add_argument( + "--ddp-static-graph", + default=False, + action="store_true", + help="Enable static graph optimization for DDP in PyTorch >= 1.11.", + ) + parser.add_argument( + "--no-set-device-rank", + default=False, + action="store_true", + help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", + ) + parser.add_argument("--seed", type=int, default=4242, help="Default random seed.") + + parser.add_argument( + "--top-k-checkpoint-select-dataset", + type=str, + default="all", + help="The dataset of selecting top-k checkpoint.", + ) + + # @R10, @R@5, @R1, mAP@10 + parser.add_argument( + "--top-k-checkpoint-select-metric", + type=str, + default="_R@10", + help="The metric for selecting top-k checkpoint.", + ) + parser.add_argument( + "--openai-model-cache-dir", + type=str, + default=f"{CACHE_DIR}/clip", + help="Directory to download OpenAI models.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="adamw", + help="can be AdamW or SGD", + ) + parser.add_argument( + "--parallel-eval", + default=False, + action="store_true", + help="Eval in parallel (multi-GPU, multi-node).", + ) + + parser.add_argument( + "--no-eval", + default=False, + action="store_true", + help="Training without evaluation.", + ) + + parser.add_argument( + "--lp-mlp", + default=False, + action="store_true", + help="Linear Probe using MLP layer or not.", + ) + + parser.add_argument( + "--lp-freeze", + default=False, + action="store_true", + help="Linear Probe using Freeze CLAP or not", + ) + + parser.add_argument( + "--lp-act", + default="None", + type=str, + help="Options are ['relu','elu','prelu','softmax','sigmoid']", + ) + + parser.add_argument( + "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe." + ) + + parser.add_argument( + "--lp-metrics", + type=str, + default="map,mauc,acc", + help="Metrics of Linear Probe.", + ) + + parser.add_argument( + "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe" + ) + parser.add_argument( + "--kappa", + type=float, + default=0, + help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss", + ) + + parser.add_argument( + "--data-filling", + type=str, + default="pad", + help="type of data filling when the audio length is shorter than the max length." + "Can be one of the following: repeat, repeatpad, pad", + ) + parser.add_argument( + "--data-truncating", + type=str, + default="rand_trunc", + help="type of data truncation when the audio length is longer than the max length." + "Can be one of the following: rand_trunc, fusion", + ) + + parser.add_argument( + "--clap-mlploss", + default=False, + action="store_true", + help="Using MLP loss for CLAP model or not", + ) + + parser.add_argument( + "--wandb-id", + type=str, + default=None, + help="the id of wandb experiment to restore.", + ) + + parser.add_argument( + "--sleep", type=float, default=0, help="sleep n seconds before start training" + ) + + # variable length processing + parser.add_argument( + "--enable-fusion", + default=False, + action="store_true", + help="Enable feature funsion for variable-length data", + ) + + parser.add_argument( + "--fusion-type", + type=str, + default="None", + help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']", + ) + + parser.add_argument( + "--mixup", + default=False, + action="store_true", + help="Enable mixup in finetuning training.", + ) + parser.add_argument( + "--text-augment-selection", + type=str, + default=None, + help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']", + ) + + args = parser.parse_args() + + # If some params are not passed, we use the default values based on model name. + default_params = get_default_params(args.amodel) + for name, val in default_params.items(): + if getattr(args, name) is None: + setattr(args, name, val) + + return args diff --git a/audioldm/clap/training/scheduler.py b/audioldm/clap/training/scheduler.py new file mode 100755 index 0000000000000000000000000000000000000000..7151ffbab25a113673b7627027b443b27f22cb0f --- /dev/null +++ b/audioldm/clap/training/scheduler.py @@ -0,0 +1,24 @@ +import numpy as np + + +def assign_learning_rate(optimizer, new_lr): + for param_group in optimizer.param_groups: + param_group["lr"] = new_lr + + +def _warmup_lr(base_lr, warmup_length, step): + return base_lr * (step + 1) / warmup_length + + +def cosine_lr(optimizer, base_lr, warmup_length, steps): + def _lr_adjuster(step): + if step < warmup_length: + lr = _warmup_lr(base_lr, warmup_length, step) + else: + e = step - warmup_length + es = steps - warmup_length + lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr + assign_learning_rate(optimizer, lr) + return lr + + return _lr_adjuster diff --git a/audioldm/clap/training/train.py b/audioldm/clap/training/train.py new file mode 100755 index 0000000000000000000000000000000000000000..f5759c4679d2ee9c0748444adf66b8453cf09728 --- /dev/null +++ b/audioldm/clap/training/train.py @@ -0,0 +1,838 @@ +import json +import logging +import math +import os +import time +from contextlib import suppress + +import numpy as np +import torch +import torch.nn.functional as F + +try: + import wandb +except ImportError: + wandb = None + +from open_clip import ClipLoss, gather_features +from .distributed import is_master +from .zero_shot import zero_shot_eval + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def unwrap_model(model): + if hasattr(model, "module"): + return model.module + else: + return model + + +def train_one_epoch( + model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None +): + device = torch.device(args.device) + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + model.train() + loss = ClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + mlp_loss=args.clap_mlploss, + weight_loss_kappa=args.kappa, + ) + + dataloader, sampler = data["train"].dataloader, data["train"].sampler + if args.distributed and sampler is not None: + sampler.set_epoch(epoch) + num_batches_per_epoch = dataloader.num_batches + sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) + + # for toy dataset + if args.dataset_type == "toy": + dataloader.dataset.generate_queue() + + loss_m = AverageMeter() + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for i, batch in enumerate(dataloader): + # logging.info(f"batch {i} of {num_batches_per_epoch}") + step = num_batches_per_epoch * epoch + i + if isinstance(scheduler, dict): + for s in scheduler.values(): + s(step) + else: + scheduler(step) + audios = batch # contains mel_spec, wavform, and longer list + texts = batch["text"] + # audios = audios.to(device=device, non_blocking=True) + # texts = texts.to(device=device, non_blocking=True) + + data_time_m.update(time.time() - end) + if isinstance(optimizer, dict): + for o_ in optimizer.values(): + o_.zero_grad() + else: + optimizer.zero_grad() + + with autocast(): + ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + logit_scale_a, + logit_scale_t, + ) = model(audios, texts, device) + + if args.clap_mlploss: + total_loss = loss( + audio_features=audio_features, + text_features=text_features, + logit_scale_a=logit_scale_a, + logit_scale_t=logit_scale_t, + audio_features_mlp=audio_features_mlp, + text_features_mlp=text_features_mlp, + ) + else: + total_loss = loss( + audio_features=audio_features, + text_features=text_features, + logit_scale_a=logit_scale_a, + ) + if isinstance(optimizer, dict): + if scaler is not None: + scaler.scale(total_loss).backward() + for o_ in optimizer.values(): + if args.horovod: + o_.synchronize() + scaler.unscale_(o_) + with o_.skip_synchronize(): + scaler.step(o_) + else: + scaler.step(o_) + scaler.update() + else: + total_loss.backward() + for o_ in optimizer.values(): + o_.step() + else: + if scaler is not None: + scaler.scale(total_loss).backward() + if args.horovod: + optimizer.synchronize() + scaler.unscale_(optimizer) + with optimizer.skip_synchronize(): + scaler.step(optimizer) + else: + scaler.step(optimizer) + scaler.update() + else: + total_loss.backward() + optimizer.step() + + # Note: we clamp to 4.6052 = ln(100), as in the original paper. + with torch.no_grad(): + unwrap_model(model).logit_scale_a.clamp_(0, math.log(100)) + if args.clap_mlploss: + unwrap_model(model).logit_scale_t.clamp_(0, math.log(100)) + + batch_time_m.update(time.time() - end) + end = time.time() + batch_count = i + 1 + if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch): + if isinstance(audios, dict): + batch_size = len(audios["waveform"]) + else: + batch_size = len(audios) + num_samples = batch_count * batch_size * args.world_size + samples_per_epoch = dataloader.num_samples + percent_complete = 100.0 * batch_count / num_batches_per_epoch + + # NOTE loss is coarsely sampled, just master node and per log update + loss_m.update(total_loss.item(), batch_size) + logit_scale_scalar_a = logit_scale_a.item() + logit_scale_scalar_t = logit_scale_t.item() + if isinstance(optimizer, dict): + if args.clap_mlploss: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + f"Logit Scale Text: {logit_scale_scalar_t:.3f}" + ) + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "scale_text": logit_scale_scalar_t, + "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], + } + else: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + ) + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], + } + + else: + if args.clap_mlploss: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + f"Logit Scale Text: {logit_scale_scalar_t:.3f}" + ) + + # Save train loss / etc. Using non avg meter values as loggers have their own smoothing + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "scale_text": logit_scale_scalar_t, + "lr": optimizer.param_groups[0]["lr"], + } + else: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + ) + + # Save train loss / etc. Using non avg meter values as loggers have their own smoothing + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "lr": optimizer.param_groups[0]["lr"], + } + for name, val in log_data.items(): + name = "train/" + name + if tb_writer is not None: + tb_writer.add_scalar(name, val, step) + if args.wandb: + assert wandb is not None, "Please install wandb." + wandb.log({name: val, "step": step}) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # end for + + +def evaluate(model, data, epoch, args, tb_writer=None): + metrics = {} + if not args.parallel_eval: + if not is_master(args): + return metrics + device = torch.device(args.device) + model.eval() + + # CHANGE + # zero_shot_metrics = zero_shot_eval(model, data, epoch, args) + # metrics.update(zero_shot_metrics) + if is_master(args): + print("Evaluating...") + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + if args.val_dataset_names == ["Clotho", "audiocaps"]: + # if only clotho and audiocaps are used, then we will use a different evaluation function. + # This is because in the Clotho and audiocaps valid and test set, there are 5 text for 1 audio. + if args.parallel_eval: + # (yusong): just a hack here. Don't use parallel eval when evaluating only clotho and audiocaps. + raise NotImplementedError( + "Parallel evaluation not supported for eval only Clotho and audiocaps." + ) + val_metrics_per_dataset = evaluate_clotho_audiocaps( + model, data, epoch, args, autocast, device, tb_writer + ) + for m in val_metrics_per_dataset.values(): + metrics.update(m) + if "epoch" not in metrics.keys(): + metrics.update({"epoch": epoch}) + metrics = select_top_metric_clotho_audiocaps( + metrics, val_metrics_per_dataset, args + ) + elif "val" in data and ( + args.val_frequency + and ((epoch % args.val_frequency) == 0 or epoch == args.epochs) + ): + dataloader = data["val"].dataloader + num_samples = 0 + samples_per_val = dataloader.num_samples + + # FIXME this does not scale past small eval datasets + # all_audio_features @ all_text_features will blow up memory and compute very quickly + eval_info = {} + if args.clap_mlploss: + eval_info["all"] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + "all_audio_features_mlp": [], + "all_text_features_mlp": [], + } # cumulative_loss = 0.0 + else: + eval_info["all"] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } # cumu + # all_audio_features, all_text_features, all_audio_features_mlp, all_text_features_mlp = [], [], [], [] + with torch.no_grad(): + for i, batch in enumerate(dataloader): + audios = batch # contains mel_spec, wavform, and longer list + texts = batch["text"] + # audios = audios.to(device=device, non_blocking=True) + + all_names = list( + set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]) + ) + for name in all_names: + if name not in eval_info.keys(): + if args.clap_mlploss: + eval_info[name] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + "all_audio_features_mlp": [], + "all_text_features_mlp": [], + } + else: + eval_info[name] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } + with autocast(): + ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + logit_scale_a, + logit_scale_t, + ) = model(audios, texts, device) + + if args.parallel_eval: + # multi-GPU eval + if args.clap_mlploss: + ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + ) = gather_features( + audio_features=audio_features, + text_features=text_features, + audio_features_mlp=audio_features_mlp, + text_features_mlp=text_features_mlp, + local_loss=False, + gather_with_grad=False, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + mlp_loss=args.clap_mlploss, + ) + else: + (audio_features, text_features,) = gather_features( + audio_features=audio_features, + text_features=text_features, + local_loss=False, + gather_with_grad=False, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + mlp_loss=args.clap_mlploss, + ) + + if is_master(args): + num_samples += audio_features.shape[0] + for n in [*all_names, "all"]: + if n == "all": + eval_info[n]["all_audio_features"].append( + audio_features.cpu() + ) + eval_info[n]["all_text_features"].append( + text_features.cpu() + ) + if args.clap_mlploss: + eval_info[n]["all_audio_features_mlp"].append( + audio_features_mlp.cpu() + ) + eval_info[n]["all_text_features_mlp"].append( + text_features_mlp.cpu() + ) + else: + idx = np.where( + np.array( + [ + "-".join(b.split("/")[-3:-1]) + for b in batch["__url__"] + ] + ) + == n + )[0] + eval_info[n]["all_audio_features"].append( + audio_features.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + eval_info[n]["all_text_features"].append( + text_features.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + if args.clap_mlploss: + eval_info[n]["all_audio_features_mlp"].append( + audio_features_mlp.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + eval_info[n]["all_text_features_mlp"].append( + text_features_mlp.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + # print(f'eval step {i}') # (yusong): for debug + + # cumulative_loss += total_loss * batch_size + # num_samples += batch_size + if is_master(args) and (i % 100) == 0: # and i != 0: + logging.info( + f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]" + ) + if is_master(args): + val_metrics_per_dataset = {} + for n in eval_info.keys(): + if args.clap_mlploss: + metrics_single_dataset = get_metrics( + audio_features=torch.cat( + eval_info[n]["all_audio_features"] + ), + text_features=torch.cat(eval_info[n]["all_text_features"]), + logit_scale_a=logit_scale_a.cpu(), + audio_features_mlp=torch.cat( + eval_info[n]["all_audio_features_mlp"] + ), + text_features_mlp=torch.cat( + eval_info[n]["all_text_features_mlp"] + ), + logit_scale_t=logit_scale_t.cpu(), + mlp_loss=args.clap_mlploss, + ) + else: + metrics_single_dataset = get_metrics( + audio_features=torch.cat( + eval_info[n]["all_audio_features"] + ), + text_features=torch.cat(eval_info[n]["all_text_features"]), + logit_scale_a=logit_scale_a.cpu(), + mlp_loss=args.clap_mlploss, + ) + val_metrics_per_dataset[n] = { + n + "/" + k: v for k, v in metrics_single_dataset.items() + } + metrics.update(val_metrics_per_dataset[n]) + if "epoch" not in metrics.keys(): + metrics.update({"epoch": epoch}) + if is_master(args): + if not metrics: + return metrics + + logging.info( + f"Eval Epoch: {epoch} " + + "\n".join( + [ + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in m.items()]) + for m in val_metrics_per_dataset.values() + ] + ) + ) + + if args.save_logs: + for name, val in metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val/{name}", val, epoch) + + with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: + f.write(json.dumps(metrics)) + f.write("\n") + + if args.wandb: + assert wandb is not None, "Please install wandb." + for name, val in metrics.items(): + wandb.log({f"val/{name}": val, "epoch": epoch}) + + return metrics + else: + return metrics + + +def get_metrics( + audio_features, + text_features, + logit_scale_a, + audio_features_mlp=None, + text_features_mlp=None, + logit_scale_t=None, + mlp_loss=False, +): + metrics = {} + if mlp_loss: + # Set up audio to text & text to audio similary matrice + a_logits_per_audio = ( + (logit_scale_a * audio_features @ text_features_mlp.t()).detach().cpu() + ) + a_logits_per_text = a_logits_per_audio.t().detach().cpu() + t_logits_per_audio = ( + (logit_scale_t * audio_features_mlp @ text_features.t()).detach().cpu() + ) + t_logits_per_text = t_logits_per_audio.t().detach().cpu() + + labels = torch.arange(audio_features.shape[0]).long() + # Change the loss from two terms into four terms with 2x2 combined CE loss + total_loss = ( + F.cross_entropy(a_logits_per_audio, labels) + + F.cross_entropy(a_logits_per_text, labels) + + F.cross_entropy(t_logits_per_audio, labels) + + F.cross_entropy(t_logits_per_text, labels) + ) / 4 + + metrics[f"cumulative_loss"] = total_loss.item() + metrics[f"num_samples"] = audio_features.shape[0] + + logits = { + "audio_to_text": (a_logits_per_audio + t_logits_per_audio) / 2, + "text_to_audio": (a_logits_per_text + t_logits_per_text) / 2, + } + ground_truth = torch.arange(len(text_features)).view(-1, 1) + + else: + # print("text_features", text_features) + # print("text_features.shape", text_features.shape) + logits_per_audio = ( + (logit_scale_a * audio_features @ text_features.t()).detach().cpu() + ) + logits_per_text = logits_per_audio.t().detach().cpu() + + labels = torch.arange(audio_features.shape[0]).long() + # Change the loss from two terms into four terms with 2x2 combined CE loss + total_loss = ( + F.cross_entropy(logits_per_audio, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + metrics[f"cumulative_loss"] = total_loss.item() + metrics[f"num_samples"] = audio_features.shape[0] + + logits = {"audio_to_text": logits_per_audio, "text_to_audio": logits_per_text} + + ground_truth = torch.arange(len(text_features)).view(-1, 1) + + for name, logit in logits.items(): + ranking = torch.argsort(logit, descending=True) + preds = torch.where(ranking == ground_truth)[ + 1 + ] # (yusong) this line is slow because it uses single thread + preds = preds.detach().cpu().numpy() + metrics[f"{name}_mean_rank"] = preds.mean() + 1 + metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1 + for k in [1, 5, 10]: + metrics[f"{name}_R@{k}"] = np.mean(preds < k) + # map@10 + metrics[f"{name}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0)) + + return metrics + + +def evaluate_clotho_audiocaps( + model, data, epoch, args, autocast, device, tb_writer=None +): + """ + Adapted from https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py. + 1. for text-to-audio retrieval, do 5 times and average the results + 2. for R@1, R@5, R@10 in audio-to-text retrieval, take the best rank among 5 text + 3. for map@10 in audio-to-text retrieval: + 3.1: sort the rank of 5 text + 3.2: exclude the rank >=10 (0-index) + 3.3: compute the map regarding the remaining ranks: np.mean(np.arange(1, len(ranks)+1) / ranks). + (3.3) That is, take the top ranks of 5 text that is < 10, and assign the descending number as ground truth. + (3.3) E.g.: the ground truth of first rank of the 5 text should be 1, the second rank should be 2, etc. + """ + # TODO: (yusong) only support single GPU evaluation and only support non-mlp case for now. + dataloader = data["val"].dataloader + with torch.no_grad(): + eval_info = {} + for i, batch in enumerate(dataloader): + audios = batch # contains mel_spec, wavform, and longer list + + # each item in the list has 5 texts + if args.tmodel == "transformer": + from open_clip import tokenize + + texts = [tokenize(t) for t in batch["full_text"]] + texts = torch.cat(texts) + else: + from .data import tokenizer + + texts = [ + tokenizer(t) for t in batch["full_text"] + ] # 5 texts for each audio + texts = { + k: torch.cat([t[k] for t in texts]) for k in texts[0].keys() + } # 5 x batch + + # audios = audios.to(device=device, non_blocking=True) + + all_names = list( + set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]) + ) + for name in all_names: + if name not in eval_info.keys(): + # we will not use mlp outputs even if args.clap_mlploss=True + eval_info[name] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } + with autocast(): + audio_features = model(audios, None, device) + text_features = model(None, texts, device) + audio_features = F.normalize(audio_features, dim=-1) + text_features = F.normalize(text_features, dim=-1) + + all_names = list( + set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]) + ) + for n in all_names: + idx = np.where( + np.array( + ["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]] + ) + == n + )[0] + eval_info[n]["all_audio_features"].append( + audio_features.cpu().index_select(0, torch.tensor(idx).long()) + ) + # (yusong) please double-check. This is for selecting 5 text features at once. + # because idx is a list of indices in size of num_samples, + # and text_features is a tensor of size (5*num_samples, dim) + # so we need to select 5 consecutive indices at once for a single index in idx. + eval_info[n]["all_text_features"].append( + text_features.cpu() + .reshape([-1, 5, text_features.shape[1]]) + .index_select(0, torch.tensor(idx).long()) + .reshape([-1, text_features.shape[1]]) + ) + + val_metrics_all = {} + + for n in eval_info.keys(): + logit_scale_a, logit_scale_t = model(None, None, device) + logit_scale_a = logit_scale_a.cpu() + + audio_features = torch.cat(eval_info[n]["all_audio_features"], dim=0) + text_features = torch.cat(eval_info[n]["all_text_features"], dim=0) + + logits_per_audio = ( + (logit_scale_a * audio_features @ text_features.t()).detach().cpu() + ) + logits_per_text = logits_per_audio.t().detach().cpu() + + # logits_per_audio shape: [num_samples, num_samples*5] + # logits_per_text shape: [num_samples*5, num_samples] + + logging.info( + f"dataset {n}, logits_per_audio shape: {logits_per_audio.shape}, " + f"logits_per_text shape: {logits_per_text.shape}" + ) + + metrics = {} + num_samples = audio_features.shape[0] + metrics[f"num_samples"] = num_samples + + # (yusong) the following code is very important, please double-check: + # logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d] + # logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :] + # Those two are retrieving one of the 5 text for each audio. + labels = torch.arange(audio_features.shape[0]).long() + audio_to_text_loss = [ + F.cross_entropy( + logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d], + labels, + ) + for d in range(5) + ] + text_to_audio_loss = [ + F.cross_entropy( + logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :], + labels, + ) + for d in range(5) + ] + total_loss = (np.mean(audio_to_text_loss) + np.mean(text_to_audio_loss)) / 2 + + metrics[f"cumulative_loss"] = total_loss.item() + + # text to audio: do 5 times + pred_text = [] + for d in range(5): + logit = logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :] + ground_truth = torch.arange(len(logit)).view(-1, 1) + ranking = torch.argsort( + logit, descending=True + ) # [num_samples, num_samples] + preds = torch.where(ranking == ground_truth)[1] + pred_text.append(preds.detach().cpu().numpy()) + pred_text_concat = np.concatenate(pred_text, axis=0) # [5*num_samples] + metrics[f"text_to_audio_mean_rank"] = pred_text_concat.mean() + 1 + metrics[f"text_to_audio_median_rank"] = ( + np.floor(np.median(pred_text_concat)) + 1 + ) + for k in [1, 5, 10]: + metrics[f"text_to_audio_R@{k}"] = np.mean(pred_text_concat < k) + # map@10 + metrics[f"text_to_audio_mAP@10"] = np.mean( + np.where(pred_text_concat < 10, 1 / (pred_text_concat + 1), 0.0) + ) + + # audio to text: take the best result + # for audio to text map 10, sort and assign descending ground truth. + # see https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py#L103 + # map@10 + map_all = [] + pred_audio_all = [] + for d in range(num_samples): + # logits_per_audio: [num_samples, num_samples*5] + logit_single = logits_per_audio[d, :] # [5*num_samples] + # Ground-truth index: [d*5, d*5+1, d*5+2, d*5+3, d*5+4] + ranking = torch.argsort( + logit_single, descending=True + ) # [5*num_samples] + # ranking: the index of first match, second match, ... + ground_truth = torch.arange(d * 5, d * 5 + 5)[None] + all_pred = torch.where( + torch.stack([ranking] * 5) == ground_truth.view(-1, 1) + )[1] + min_pred = torch.min(all_pred) + pred_audio_all.append(min_pred.detach().cpu().numpy()) + all_pred_filter = all_pred[all_pred < 10].detach().cpu().numpy() + # /5 because we have 5 text, so it means for the text rank >=10 we count as 0. + map_single = ( + np.sum( + (np.arange(1, len(all_pred_filter) + 1) / (all_pred_filter + 1)) + ) + / 5 + ) + map_all.append(map_single) + metrics[f"audio_to_text_mAP@10"] = np.mean(map_all) + for k in [1, 5, 10]: + metrics[f"audio_to_text_R@{k}"] = np.mean(np.array(pred_audio_all) < k) + + val_metrics_all[n] = {n + "/" + k: v for k, v in metrics.items()} + return val_metrics_all + + +def calculate_selection_performance_clotho_audiocaps(val_metrics_per_dataset): + """ + Calculate performance for Clotho+AudioCaps for model selection. + """ + selection_performance_all = [] + for n in val_metrics_per_dataset.keys(): + selection_performance = ( + val_metrics_per_dataset[n][f"{n}/audio_to_text_mAP@10"] + + val_metrics_per_dataset[n][f"{n}/text_to_audio_mAP@10"] + ) / 2 + selection_performance_all.append(selection_performance) + return np.mean(selection_performance_all) + + +def select_top_metric_clotho_audiocaps(metrics, val_metrics_per_dataset, args): + # val_metrics_per_dataset: dict, key: dataset name, value: dict, key: metric name, value: metric value + # metrics: dict, key: metric name, value: metric value + # Hack: use args to save the top performance + if not hasattr(args, "top_selection_performance"): + selection_performance = calculate_selection_performance_clotho_audiocaps( + val_metrics_per_dataset + ) + # TODO: write the if and else together + metric_update = {} + for n in val_metrics_per_dataset.keys(): + for k in val_metrics_per_dataset[n].keys(): + metric_update[ + k.split("/")[0] + "-top" + "/" + k.split("/")[1] + ] = val_metrics_per_dataset[n][k] + metric_update["top_selection_performance"] = selection_performance + metric_update["top-selection-epoch"] = metrics["epoch"] + metrics.update(metric_update) + args.top_metric = metric_update + args.top_selection_performance = selection_performance + else: + selection_performance_new = calculate_selection_performance_clotho_audiocaps( + val_metrics_per_dataset + ) + selection_performance_old = args.top_selection_performance + if selection_performance_new > selection_performance_old: + metric_update = {} + for n in val_metrics_per_dataset.keys(): + for k in val_metrics_per_dataset[n].keys(): + metric_update[ + k.split("/")[0] + "-top" + "/" + k.split("/")[1] + ] = val_metrics_per_dataset[n][k] + metric_update["top_selection_performance"] = selection_performance_new + metric_update["top-selection-epoch"] = metrics["epoch"] + metrics.update(metric_update) + args.top_metric = metric_update + args.top_selection_performance = selection_performance_new + else: + metrics.update(args.top_metric) + return metrics diff --git a/audioldm/clap/training/zero_shot.py b/audioldm/clap/training/zero_shot.py new file mode 100755 index 0000000000000000000000000000000000000000..28b8fccc1af17fc69002857a7f529ac041c374f2 --- /dev/null +++ b/audioldm/clap/training/zero_shot.py @@ -0,0 +1,95 @@ +# NOTE: This script is currently not supported for CLAP. +import logging +from contextlib import suppress + +import torch +import torch.nn.functional as F +from tqdm import tqdm + +from open_clip import tokenize +from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template + + +def zero_shot_classifier(model, classnames, templates, args): + with torch.no_grad(): + zeroshot_weights = [] + for classname in tqdm(classnames): + texts = [template(classname) for template in templates] # format with class + texts = tokenize(texts).to(args.device) # tokenize + if args.distributed and not args.horovod: + class_embeddings = model.module.encode_text(texts) + else: + class_embeddings = model.encode_text(texts) + class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device) + return zeroshot_weights + + +def accuracy(output, target, topk=(1,)): + pred = output.topk(max(topk), 1, True, True)[1].t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + return [ + float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) + for k in topk + ] + + +def run(model, classifier, dataloader, args): + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + with torch.no_grad(): + top1, top5, n = 0.0, 0.0, 0.0 + for images, target in tqdm(dataloader, unit_scale=args.batch_size): + images = images.to(args.device) + target = target.to(args.device) + + with autocast(): + # predict + if args.distributed and not args.horovod: + image_features = model.module.encode_image(images) + else: + image_features = model.encode_image(images) + image_features = F.normalize(image_features, dim=-1) + logits = 100.0 * image_features @ classifier + + # measure accuracy + acc1, acc5 = accuracy(logits, target, topk=(1, 5)) + top1 += acc1 + top5 += acc5 + n += images.size(0) + + top1 = top1 / n + top5 = top5 / n + return top1, top5 + + +def zero_shot_eval(model, data, epoch, args): + if "imagenet-val" not in data and "imagenet-v2" not in data: + return {} + if args.zeroshot_frequency == 0: + return {} + if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: + return {} + + logging.info("Starting zero-shot imagenet.") + + logging.info("Building zero-shot classifier") + classifier = zero_shot_classifier( + model, imagenet_classnames, openai_imagenet_template, args + ) + + logging.info("Using classifier") + results = {} + if "imagenet-val" in data: + top1, top5 = run(model, classifier, data["imagenet-val"].dataloader, args) + results["imagenet-zeroshot-val-top1"] = top1 + results["imagenet-zeroshot-val-top5"] = top5 + if "imagenet-v2" in data: + top1, top5 = run(model, classifier, data["imagenet-v2"].dataloader, args) + results["imagenetv2-zeroshot-val-top1"] = top1 + results["imagenetv2-zeroshot-val-top5"] = top5 + + logging.info("Finished zero-shot imagenet.") + + return results diff --git a/audioldm/hifigan/__init__.py b/audioldm/hifigan/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e0ae476fe58c48e998c56234a55b871beba4042d --- /dev/null +++ b/audioldm/hifigan/__init__.py @@ -0,0 +1,7 @@ +from .models import Generator + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self diff --git a/audioldm/hifigan/models.py b/audioldm/hifigan/models.py new file mode 100755 index 0000000000000000000000000000000000000000..c4382cc39de0463f9b7c0f33f037dbc233e7cb36 --- /dev/null +++ b/audioldm/hifigan/models.py @@ -0,0 +1,174 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm + +LRELU_SLOPE = 0.1 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock, self).__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm( + Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) + ) + resblock = ResBlock + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + ): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + # print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) diff --git a/audioldm/hifigan/utilities.py b/audioldm/hifigan/utilities.py new file mode 100755 index 0000000000000000000000000000000000000000..47fd39ea0af181772d640feec2413cf631a75702 --- /dev/null +++ b/audioldm/hifigan/utilities.py @@ -0,0 +1,85 @@ +import os +import json + +import torch +import numpy as np + +import audioldm.hifigan as hifigan + +HIFIGAN_16K_64 = { + "resblock": "1", + "num_gpus": 6, + "batch_size": 16, + "learning_rate": 0.0002, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], + "upsample_initial_channel": 1024, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "segment_size": 8192, + "num_mels": 64, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 160, + "win_size": 1024, + "sampling_rate": 16000, + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": None, + "num_workers": 4, + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1, + }, +} + + +def get_available_checkpoint_keys(model, ckpt): + print("==> Attemp to reload from %s" % ckpt) + state_dict = torch.load(ckpt)["state_dict"] + current_state_dict = model.state_dict() + new_state_dict = {} + for k in state_dict.keys(): + if ( + k in current_state_dict.keys() + and current_state_dict[k].size() == state_dict[k].size() + ): + new_state_dict[k] = state_dict[k] + else: + print("==> WARNING: Skipping %s" % k) + print( + "%s out of %s keys are matched" + % (len(new_state_dict.keys()), len(state_dict.keys())) + ) + return new_state_dict + + +def get_param_num(model): + num_param = sum(param.numel() for param in model.parameters()) + return num_param + + +def get_vocoder(config, device): + config = hifigan.AttrDict(HIFIGAN_16K_64) + vocoder = hifigan.Generator(config) + vocoder.eval() + vocoder.remove_weight_norm() + vocoder.to(device) + return vocoder + + +def vocoder_infer(mels, vocoder, lengths=None): + with torch.no_grad(): + wavs = vocoder(mels).squeeze(1) + + wavs = (wavs.cpu().numpy() * 32768).astype("int16") + + if lengths is not None: + wavs = wavs[:, :lengths] + + return wavs diff --git a/audioldm/latent_diffusion/__init__.py b/audioldm/latent_diffusion/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm/latent_diffusion/attention.py b/audioldm/latent_diffusion/attention.py new file mode 100755 index 0000000000000000000000000000000000000000..27886f5ee3c7eb856100503b838399106ef00051 --- /dev/null +++ b/audioldm/latent_diffusion/attention.py @@ -0,0 +1,469 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn +from einops import rearrange + +from audioldm.latent_diffusion.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class CrossAttention(nn.Module): + """ + ### Cross Attention Layer + This falls-back to self-attention when conditional embeddings are not specified. + """ + + # use_flash_attention: bool = True + use_flash_attention: bool = False + + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + is_inplace: bool = True, + ): + # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True): + """ + :param d_model: is the input embedding size + :param n_heads: is the number of attention heads + :param d_head: is the size of a attention head + :param d_cond: is the size of the conditional embeddings + :param is_inplace: specifies whether to perform the attention softmax computation inplace to + save memory + """ + super().__init__() + + self.is_inplace = is_inplace + self.n_heads = heads + self.d_head = dim_head + + # Attention scaling factor + self.scale = dim_head**-0.5 + + # The normal self-attention layer + if context_dim is None: + context_dim = query_dim + + # Query, key and value mappings + d_attn = dim_head * heads + self.to_q = nn.Linear(query_dim, d_attn, bias=False) + self.to_k = nn.Linear(context_dim, d_attn, bias=False) + self.to_v = nn.Linear(context_dim, d_attn, bias=False) + + # Final linear layer + self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout)) + + # Setup [flash attention](https://github.com/HazyResearch/flash-attention). + # Flash attention is only used if it's installed + # and `CrossAttention.use_flash_attention` is set to `True`. + try: + # You can install flash attention by cloning their Github repo, + # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention) + # and then running `python setup.py install` + from flash_attn.flash_attention import FlashAttention + + self.flash = FlashAttention() + # Set the scale for scaled dot-product attention. + self.flash.softmax_scale = self.scale + # Set to `None` if it's not installed + except ImportError: + self.flash = None + + def forward(self, x, context=None, mask=None): + """ + :param x: are the input embeddings of shape `[batch_size, height * width, d_model]` + :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]` + """ + + # If `cond` is `None` we perform self attention + has_cond = context is not None + if not has_cond: + context = x + + # Get query, key and value vectors + q = self.to_q(x) + k = self.to_k(context) + v = self.to_v(context) + + # Use flash attention if it's available and the head size is less than or equal to `128` + if ( + CrossAttention.use_flash_attention + and self.flash is not None + and not has_cond + and self.d_head <= 128 + ): + return self.flash_attention(q, k, v) + # Otherwise, fallback to normal attention + else: + return self.normal_attention(q, k, v) + + def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """ + #### Flash Attention + :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + """ + + # Get batch size and number of elements along sequence axis (`width * height`) + batch_size, seq_len, _ = q.shape + + # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of + # shape `[batch_size, seq_len, 3, n_heads * d_head]` + qkv = torch.stack((q, k, v), dim=2) + # Split the heads + qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head) + + # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to + # fit this size. + if self.d_head <= 32: + pad = 32 - self.d_head + elif self.d_head <= 64: + pad = 64 - self.d_head + elif self.d_head <= 128: + pad = 128 - self.d_head + else: + raise ValueError(f"Head size ${self.d_head} too large for Flash Attention") + + # Pad the heads + if pad: + qkv = torch.cat( + (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1 + ) + + # Compute attention + # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ + # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]` + # TODO here I add the dtype changing + out, _ = self.flash(qkv.type(torch.float16)) + # Truncate the extra head size + out = out[:, :, :, : self.d_head].float() + # Reshape to `[batch_size, seq_len, n_heads * d_head]` + out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head) + + # Map to `[batch_size, height * width, d_model]` with a linear layer + return self.to_out(out) + + def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """ + #### Normal Attention + + :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + """ + + # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]` + q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32] + k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32] + v = v.view(*v.shape[:2], self.n_heads, -1) + + # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$ + attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale + + # Compute softmax + # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$ + if self.is_inplace: + half = attn.shape[0] // 2 + attn[half:] = attn[half:].softmax(dim=-1) + attn[:half] = attn[:half].softmax(dim=-1) + else: + attn = attn.softmax(dim=-1) + + # Compute attention output + # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ + # attn: [bs, 20, 64, 1] + # v: [bs, 1, 20, 32] + out = torch.einsum("bhij,bjhd->bihd", attn, v) + # Reshape to `[batch_size, height * width, n_heads * d_head]` + out = out.reshape(*out.shape[:2], -1) + # Map to `[batch_size, height * width, d_model]` with a linear layer + return self.to_out(out) + + +# class CrossAttention(nn.Module): +# def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): +# super().__init__() +# inner_dim = dim_head * heads +# context_dim = default(context_dim, query_dim) + +# self.scale = dim_head ** -0.5 +# self.heads = heads + +# self.to_q = nn.Linear(query_dim, inner_dim, bias=False) +# self.to_k = nn.Linear(context_dim, inner_dim, bias=False) +# self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + +# self.to_out = nn.Sequential( +# nn.Linear(inner_dim, query_dim), +# nn.Dropout(dropout) +# ) + +# def forward(self, x, context=None, mask=None): +# h = self.heads + +# q = self.to_q(x) +# context = default(context, x) +# k = self.to_k(context) +# v = self.to_v(context) + +# q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + +# sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + +# if exists(mask): +# mask = rearrange(mask, 'b ... -> b (...)') +# max_neg_value = -torch.finfo(sim.dtype).max +# mask = repeat(mask, 'b j -> (b h) () j', h=h) +# sim.masked_fill_(~mask, max_neg_value) + +# # attention, what we cannot get enough of +# attn = sim.softmax(dim=-1) + +# out = einsum('b i j, b j d -> b i d', attn, v) +# out = rearrange(out, '(b h) n d -> b n (h d)', h=h) +# return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + if context is None: + return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint) + else: + return checkpoint( + self._forward, (x, context), self.parameters(), self.checkpoint + ) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + no_context=False, + ): + super().__init__() + + if no_context: + context_dim = None + + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim + ) + for d in range(depth) + ] + ) + + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + return x + x_in diff --git a/audioldm/latent_diffusion/ddim.py b/audioldm/latent_diffusion/ddim.py new file mode 100755 index 0000000000000000000000000000000000000000..53c3490f5c894e2b80e09d50e7f0b367691ea2a9 --- /dev/null +++ b/audioldm/latent_diffusion/ddim.py @@ -0,0 +1,377 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm + +from audioldm.latent_diffusion.util import ( + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, + extract_into_tensor, +) + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs, + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print( + f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + ) + else: + if conditioning.shape[0] != batch_size: + print( + f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + ) + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + samples, intermediates = self.ddim_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + ): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = ( + reversed(range(0, timesteps)) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + # iterator = gr.Progress().tqdm(time_range, desc="DDIM Sampler", total=total_steps) + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample( + x0, ts + ) # TODO deterministic forward pass? + img = ( + img_orig * mask + (1.0 - mask) * img + ) # In the first sampling step, img is pure gaussian noise + + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + img, pred_x0 = outs + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + ) + + @torch.no_grad() + def decode( + self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + ): + + timesteps = ( + np.arange(self.ddpm_num_timesteps) + if use_original_steps + else self.ddim_timesteps + ) + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + # iterator = gr.Progress().tqdm(time_range, desc="Decoding image", total=total_steps) + iterator = tqdm(time_range, desc="Decoding image", total=total_steps) + x_dec = x_latent + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full( + (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long + ) + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return x_dec + + @torch.no_grad() + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + ): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + # When unconditional_guidance_scale == 1: only e_t + # When unconditional_guidance_scale == 0: only unconditional + # When unconditional_guidance_scale > 1: add more unconditional guidance + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score( + self.model, e_t, x, t, c, **corrector_kwargs + ) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise # TODO + return x_prev, pred_x0 diff --git a/audioldm/latent_diffusion/ddpm.py b/audioldm/latent_diffusion/ddpm.py new file mode 100755 index 0000000000000000000000000000000000000000..ff9e0d362744db534bd8f019fe9d647e72266649 --- /dev/null +++ b/audioldm/latent_diffusion/ddpm.py @@ -0,0 +1,442 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" +import sys +import os + +import torch +import torch.nn as nn +import numpy as np +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm + +from audioldm.utils import exists, default, count_params, instantiate_from_config +from audioldm.latent_diffusion.ema import LitEma +from audioldm.latent_diffusion.util import ( + make_beta_schedule, + extract_into_tensor, + noise_like, +) + +import soundfile as sf +import os + + +__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DiffusionWrapper(nn.Module): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [ + None, + "concat", + "crossattn", + "hybrid", + "adm", + "film", + ] + + def forward( + self, x, t, c_concat: list = None, c_crossattn: list = None, c_film: list = None + ): + x = x.contiguous() + t = t.contiguous() + + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == "concat": + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == "crossattn": + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == "hybrid": + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif ( + self.conditioning_key == "film" + ): # The condition is assumed to be a global token, which wil pass through a linear layer and added with the time embedding for the FILM + cc = c_film[0].squeeze(1) # only has one token + out = self.diffusion_model(x, t, y=cc) + elif self.conditioning_key == "adm": + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class DDPM(nn.Module): + # classic DDPM with Gaussian diffusion, in image space + def __init__( + self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + latent_t_size=256, + latent_f_size=16, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0.0, + v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1.0, + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0.0, + ): + super().__init__() + assert parameterization in [ + "eps", + "x0", + ], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + self.state = None + # print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + + self.latent_t_size = latent_t_size + self.latent_f_size = latent_f_size + + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + else: + self.logvar = nn.Parameter(self.logvar, requires_grad=False) + + self.logger_save_dir = None + self.logger_project = None + self.logger_version = None + self.label_indices_total = None + # To avoid the system cannot find metric value for checkpoint + self.metrics_buffer = { + "val/kullback_leibler_divergence_sigmoid": 15.0, + "val/kullback_leibler_divergence_softmax": 10.0, + "val/psnr": 0.0, + "val/ssim": 0.0, + "val/inception_score_mean": 1.0, + "val/inception_score_std": 0.0, + "val/kernel_inception_distance_mean": 0.0, + "val/kernel_inception_distance_std": 0.0, + "val/frechet_inception_distance": 133.0, + "val/frechet_audio_distance": 32.0, + } + self.initial_learning_rate = None + + def get_log_dir(self): + if ( + self.logger_save_dir is None + and self.logger_project is None + and self.logger_version is None + ): + return os.path.join( + self.logger.save_dir, self.logger._project, self.logger.version + ) + else: + return os.path.join( + self.logger_save_dir, self.logger_project, self.logger_version + ) + + def set_log_dir(self, save_dir, project, version): + self.logger_save_dir = save_dir + self.logger_project = project + self.logger_version = version + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) + ) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * ( + 1.0 - alphas_cumprod_prev + ) / (1.0 - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + "posterior_log_variance_clipped", + to_torch(np.log(np.maximum(posterior_variance, 1e-20))), + ) + self.register_buffer( + "posterior_mean_coef1", + to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), + ) + self.register_buffer( + "posterior_mean_coef2", + to_torch( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) + ), + ) + + if self.parameterization == "eps": + lvlb_weights = self.betas**2 / ( + 2 + * self.posterior_variance + * to_torch(alphas) + * (1 - self.alphas_cumprod) + ) + elif self.parameterization == "x0": + lvlb_weights = ( + 0.5 + * np.sqrt(torch.Tensor(alphas_cumprod)) + / (2.0 * 1 - torch.Tensor(alphas_cumprod)) + ) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + # print(f"{context}: Switched to EMA weights") + pass + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + # print(f"{context}: Restored training weights") + pass + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1.0, 1.0) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t + ) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised + ) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = ( + (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous() + ) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm( + reversed(range(0, self.num_timesteps)), + desc="Sampling t", + total=self.num_timesteps, + ): + img = self.p_sample( + img, + torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised, + ) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + shape = (batch_size, channels, self.latent_t_size, self.latent_f_size) + channels = self.channels + return self.p_sample_loop(shape, return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def forward(self, x, *args, **kwargs): + t = torch.randint( + 0, self.num_timesteps, (x.shape[0],), device=self.device + ).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + # fbank, log_magnitudes_stft, label_indices, fname, waveform, clip_label, text = batch + fbank, log_magnitudes_stft, label_indices, fname, waveform, text = batch + ret = {} + + ret["fbank"] = ( + fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float() + ) + ret["stft"] = log_magnitudes_stft.to( + memory_format=torch.contiguous_format + ).float() + # ret["clip_label"] = clip_label.to(memory_format=torch.contiguous_format).float() + ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float() + ret["text"] = list(text) + ret["fname"] = fname + + return ret[k] diff --git a/audioldm/latent_diffusion/ema.py b/audioldm/latent_diffusion/ema.py new file mode 100755 index 0000000000000000000000000000000000000000..880ca3d205d9b4d7450e146930a93f2e63c58b70 --- /dev/null +++ b/audioldm/latent_diffusion/ema.py @@ -0,0 +1,82 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + torch.tensor(0, dtype=torch.int) + if use_num_upates + else torch.tensor(-1, dtype=torch.int), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_( + one_minus_decay * (shadow_params[sname] - m_param[key]) + ) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/audioldm/latent_diffusion/openaimodel.py b/audioldm/latent_diffusion/openaimodel.py new file mode 100755 index 0000000000000000000000000000000000000000..831d7aafb36bba16888e4389153979a6c13639f5 --- /dev/null +++ b/audioldm/latent_diffusion/openaimodel.py @@ -0,0 +1,1069 @@ +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from audioldm.latent_diffusion.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from audioldm.latent_diffusion.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1).contiguous() # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + "Learned 2x upsampling without padding" + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d( + self.channels, self.out_channels, kernel_size=ks, stride=2 + ) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1).contiguous() + qkv = self.qkv(self.norm(x)).contiguous() + h = self.attention(qkv).contiguous() + h = self.proj_out(h).contiguous() + return (x + h).reshape(b, c, *spatial).contiguous() + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = ( + qkv.reshape(bs * self.n_heads, ch * 3, length).contiguous().split(ch, dim=1) + ) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length).contiguous() + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum( + "bts,bcs->bct", + weight, + v.reshape(bs * self.n_heads, ch, length).contiguous(), + ) + return a.reshape(bs, -1, length).contiguous() + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + extra_film_condition_dim=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + extra_film_use_concat=False, # If true, concatenate extrafilm condition with time embedding, else addition + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.extra_film_condition_dim = extra_film_condition_dim + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + self.extra_film_use_concat = extra_film_use_concat + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + assert not ( + self.num_classes is not None and self.extra_film_condition_dim is not None + ), "As for the condition of theh UNet model, you can only set using class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim." + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.use_extra_film_by_concat = ( + self.extra_film_condition_dim is not None and self.extra_film_use_concat + ) + self.use_extra_film_by_addition = ( + self.extra_film_condition_dim is not None and not self.extra_film_use_concat + ) + + if self.extra_film_condition_dim is not None: + self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim) + # print("+ Use extra condition on UNet channel using Film. Extra condition dimension is %s. " % self.extra_film_condition_dim) + # if(self.use_extra_film_by_concat): + # print("\t By concatenation with time embedding") + # elif(self.use_extra_film_by_concat): + # print("\t By addition with time embedding") + + if use_spatial_transformer and ( + self.use_extra_film_by_concat or self.use_extra_film_by_addition + ): + # print("+ Spatial transformer will only be used as self-attention. Because you have choose to use film as your global condition.") + spatial_transformer_no_context = True + else: + spatial_transformer_no_context = False + + if use_spatial_transformer and not spatial_transformer_no_context: + assert ( + context_dim is not None + ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." + + if context_dim is not None and not spatial_transformer_no_context: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." + from omegaconf.listconfig import ListConfig + + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + no_context=spatial_transformer_no_context, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + no_context=spatial_transformer_no_context, + ), + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + no_context=spatial_transformer_no_context, + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + self.shape_reported = False + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional + :return: an [N x C x ...] Tensor of outputs. + """ + if not self.shape_reported: + # print("The shape of UNet input is", x.size()) + self.shape_reported = True + + assert (y is not None) == ( + self.num_classes is not None or self.extra_film_condition_dim is not None + ), "must specify y if and only if the model is class-conditional or film embedding conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + if self.use_extra_film_by_addition: + emb = emb + self.film_emb(y) + elif self.use_extra_film_by_concat: + emb = th.cat([emb, self.film_emb(y)], dim=-1) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) diff --git a/audioldm/latent_diffusion/util.py b/audioldm/latent_diffusion/util.py new file mode 100755 index 0000000000000000000000000000000000000000..8b289f6aa7f22a070870d8a706f944dc8547e936 --- /dev/null +++ b/audioldm/latent_diffusion/util.py @@ -0,0 +1,295 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from audioldm.utils import instantiate_from_config + + +def make_beta_schedule( + schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 +): + if schedule == "linear": + betas = ( + torch.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 + ) + ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace( + linear_start, linear_end, n_timestep, dtype=torch.float64 + ) + elif schedule == "sqrt": + betas = ( + torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + ** 0.5 + ) + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps( + ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True +): + if ddim_discr_method == "uniform": + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == "quad": + ddim_timesteps = ( + (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 + ).astype(int) + else: + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f"Selected timesteps for ddim sampler: {steps_out}") + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + ) + if verbose: + print( + f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" + ) + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t).contiguous() + return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous() + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/audioldm/ldm.py b/audioldm/ldm.py new file mode 100755 index 0000000000000000000000000000000000000000..5ad9ef16a2704ad6660a05d33d433a1e79593866 --- /dev/null +++ b/audioldm/ldm.py @@ -0,0 +1,824 @@ +import os + +import torch +import numpy as np +from tqdm import tqdm +import os +try: + from audioldm.utils import default, instantiate_from_config, save_wave + from audioldm.latent_diffusion.ddpm import DDPM + from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution + from audioldm.latent_diffusion.util import noise_like + from audioldm.latent_diffusion.ddim import DDIMSampler +except ModuleNotFoundError: + from .utils import default, instantiate_from_config, save_wave + from .latent_diffusion.ddpm import DDPM + from .variational_autoencoder.distributions import DiagonalGaussianDistribution + from .latent_diffusion.util import noise_like + from .latent_diffusion.ddim import DDIMSampler + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class LatentDiffusion(DDPM): + """main class""" + + def __init__( + self, + device="cuda", + first_stage_config=None, + cond_stage_config=None, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + base_learning_rate=None, + *args, + **kwargs, + ): + self.device = device + self.learning_rate = base_learning_rate + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs["timesteps"] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = "concat" if concat_mode else "crossattn" + if cond_stage_config == "__is_unconditional__": + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + self.cond_stage_key_orig = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer("scale_factor", torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + + def make_cond_schedule( + self, + ): + self.cond_ids = torch.full( + size=(self.num_timesteps,), + fill_value=self.num_timesteps - 1, + dtype=torch.long, + ) + ids = torch.round( + torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) + ).long() + self.cond_ids[: self.num_timesteps_cond] = ids + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + super().register_schedule( + given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s + ) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != "__is_first_stage__" + assert config != "__is_unconditional__" + model = instantiate_from_config(config) + self.cond_stage_model = model + self.cond_stage_model = self.cond_stage_model.to(self.device) + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, "encode") and callable( + self.cond_stage_model.encode + ): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + # Text input is list + if type(c) == list and len(c) == 1: + c = self.cond_stage_model([c[0], c[0]]) + c = c[0:1] + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + @torch.no_grad() + def get_input( + self, + batch, + k, + return_first_stage_encode=True, + return_first_stage_outputs=False, + force_c_encode=False, + cond_key=None, + return_original_cond=False, + bs=None, + ): + x = super().get_input(batch, k) + + if bs is not None: + x = x[:bs] + + x = x.to(self.device) + + if return_first_stage_encode: + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + else: + z = None + + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ["caption", "coordinates_bbox"]: + xc = batch[cond_key] + elif cond_key == "class_label": + xc = batch + else: + # [bs, 1, 527] + xc = super().get_input(batch, cond_key) + if type(xc) == torch.Tensor: + xc = xc.to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + + if bs is not None: + c = c[:bs] + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {"pos_x": pos_x, "pos_y": pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, "b h w c -> b c h w").contiguous() + + z = 1.0 / self.scale_factor * z + return self.first_stage_model.decode(z) + + def mel_spectrogram_to_waveform(self, mel): + # Mel: [bs, 1, t-steps, fbins] + if len(mel.size()) == 4: + mel = mel.squeeze(1) + mel = mel.permute(0, 2, 1) + waveform = self.first_stage_model.vocoder(mel) + waveform = waveform.cpu().detach().numpy() + return waveform + + @torch.no_grad() + def encode_first_stage(self, x): + return self.first_stage_model.encode(x) + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + if self.model.conditioning_key == "concat": + key = "c_concat" + elif self.model.conditioning_key == "crossattn": + key = "c_crossattn" + else: + key = "c_film" + + cond = {key: cond} + + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def p_mean_variance( + self, + x, + c, + t, + clip_denoised: bool, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, + ): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score( + self, model_out, x, t, c, **corrector_kwargs + ) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1.0, 1.0) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t + ) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample( + self, + x, + c, + t, + clip_denoised=False, + repeat_noise=False, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + ): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance( + x=x, + c=c, + t=t, + clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = ( + (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous() + ) + + if return_codebook_ids: + return model_mean + nonzero_mask * ( + 0.5 * model_log_variance + ).exp() * noise, logits.argmax(dim=1) + if return_x0: + return ( + model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, + x0, + ) + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising( + self, + cond, + shape, + verbose=True, + callback=None, + quantize_denoised=False, + img_callback=None, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + batch_size=None, + x_T=None, + start_T=None, + log_every_t=None, + ): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = { + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = ( + [c[:batch_size] for c in cond] + if isinstance(cond, list) + else cond[:batch_size] + ) + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = ( + tqdm( + reversed(range(0, timesteps)), + desc="Progressive Generation", + total=timesteps, + ) + if verbose + else reversed(range(0, timesteps)) + ) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != "hybrid" + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + return_x0=True, + temperature=temperature[i], + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1.0 - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: + callback(i) + if img_callback: + img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop( + self, + cond, + shape, + return_intermediates=False, + x_T=None, + verbose=True, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + start_T=None, + log_every_t=None, + ): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != "hybrid" + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + ) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1.0 - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: + callback(i) + if img_callback: + img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample( + self, + cond, + batch_size=16, + return_intermediates=False, + x_T=None, + verbose=True, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + shape=None, + **kwargs, + ): + if shape is None: + shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size) + if cond is not None: + if isinstance(cond, dict): + cond = { + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = ( + [c[:batch_size] for c in cond] + if isinstance(cond, list) + else cond[:batch_size] + ) + return self.p_sample_loop( + cond, + shape, + return_intermediates=return_intermediates, + x_T=x_T, + verbose=verbose, + timesteps=timesteps, + quantize_denoised=quantize_denoised, + mask=mask, + x0=x0, + **kwargs, + ) + + @torch.no_grad() + def sample_log( + self, + cond, + batch_size, + ddim, + ddim_steps, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_plms=False, + mask=None, + **kwargs, + ): + + if mask is not None: + shape = (self.channels, mask.size()[-2], mask.size()[-1]) + else: + shape = (self.channels, self.latent_t_size, self.latent_f_size) + + intermediate = None + if ddim and not use_plms: + # print("Use ddim sampler") + + ddim_sampler = DDIMSampler(self) + samples, intermediates = ddim_sampler.sample( + ddim_steps, + batch_size, + shape, + cond, + verbose=False, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + mask=mask, + **kwargs, + ) + + else: + # print("Use DDPM sampler") + samples, intermediates = self.sample( + cond=cond, + batch_size=batch_size, + return_intermediates=True, + unconditional_guidance_scale=unconditional_guidance_scale, + mask=mask, + unconditional_conditioning=unconditional_conditioning, + **kwargs, + ) + + return samples, intermediate + + @torch.no_grad() + def generate_sample( + self, + batchs, + ddim_steps=200, + ddim_eta=1.0, + x_T=None, + n_candidate_gen_per_text=1, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + name="waveform", + use_plms=False, + save=False, + **kwargs, + ): + # Generate n_candidate_gen_per_text times and select the best + # Batch: audio, text, fnames + assert x_T is None + try: + batchs = iter(batchs) + except TypeError: + raise ValueError("The first input argument should be an iterable object") + + if use_plms: + assert ddim_steps is not None + use_ddim = ddim_steps is not None + # waveform_save_path = os.path.join(self.get_log_dir(), name) + # os.makedirs(waveform_save_path, exist_ok=True) + # print("Waveform save path: ", waveform_save_path) + + with self.ema_scope("Generate"): + for batch in batchs: + z, c = self.get_input( + batch, + self.first_stage_key, + cond_key=self.cond_stage_key, + return_first_stage_outputs=False, + force_c_encode=True, + return_original_cond=False, + bs=None, + ) + text = super().get_input(batch, "text") + + # Generate multiple samples + batch_size = z.shape[0] * n_candidate_gen_per_text + c = torch.cat([c] * n_candidate_gen_per_text, dim=0) + text = text * n_candidate_gen_per_text + + if unconditional_guidance_scale != 1.0: + unconditional_conditioning = ( + self.cond_stage_model.get_unconditional_condition(batch_size) + ) + + samples, _ = self.sample_log( + cond=c, + batch_size=batch_size, + x_T=x_T, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + use_plms=use_plms, + ) + + if(torch.max(torch.abs(samples)) > 1e2): + samples = torch.clip(samples, min=-10, max=10) + + mel = self.decode_first_stage(samples) + + waveform = self.mel_spectrogram_to_waveform(mel) + + if waveform.shape[0] > 1: + similarity = self.cond_stage_model.cos_similarity( + torch.FloatTensor(waveform).squeeze(1), text + ) + + best_index = [] + for i in range(z.shape[0]): + candidates = similarity[i :: z.shape[0]] + max_index = torch.argmax(candidates).item() + best_index.append(i + max_index * z.shape[0]) + + waveform = waveform[best_index] + # print("Similarity between generated audio and text", similarity) + # print("Choose the following indexes:", best_index) + + return waveform + + @torch.no_grad() + def generate_sample_masked( + self, + batchs, + ddim_steps=200, + ddim_eta=1.0, + x_T=None, + n_candidate_gen_per_text=1, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + name="waveform", + use_plms=False, + time_mask_ratio_start_and_end=(0.25, 0.75), + freq_mask_ratio_start_and_end=(0.75, 1.0), + save=False, + **kwargs, + ): + # Generate n_candidate_gen_per_text times and select the best + # Batch: audio, text, fnames + assert x_T is None + try: + batchs = iter(batchs) + except TypeError: + raise ValueError("The first input argument should be an iterable object") + + if use_plms: + assert ddim_steps is not None + use_ddim = ddim_steps is not None + # waveform_save_path = os.path.join(self.get_log_dir(), name) + # os.makedirs(waveform_save_path, exist_ok=True) + # print("Waveform save path: ", waveform_save_path) + + with self.ema_scope("Generate"): + for batch in batchs: + z, c = self.get_input( + batch, + self.first_stage_key, + cond_key=self.cond_stage_key, + return_first_stage_outputs=False, + force_c_encode=True, + return_original_cond=False, + bs=None, + ) + text = super().get_input(batch, "text") + + # Generate multiple samples + batch_size = z.shape[0] * n_candidate_gen_per_text + + _, h, w = z.shape[0], z.shape[2], z.shape[3] + + mask = torch.ones(batch_size, h, w).to(self.device) + + mask[:, int(h * time_mask_ratio_start_and_end[0]) : int(h * time_mask_ratio_start_and_end[1]), :] = 0 + mask[:, :, int(w * freq_mask_ratio_start_and_end[0]) : int(w * freq_mask_ratio_start_and_end[1])] = 0 + mask = mask[:, None, ...] + + c = torch.cat([c] * n_candidate_gen_per_text, dim=0) + text = text * n_candidate_gen_per_text + + if unconditional_guidance_scale != 1.0: + unconditional_conditioning = ( + self.cond_stage_model.get_unconditional_condition(batch_size) + ) + + samples, _ = self.sample_log( + cond=c, + batch_size=batch_size, + x_T=x_T, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + use_plms=use_plms, mask=mask, x0=torch.cat([z] * n_candidate_gen_per_text) + ) + + mel = self.decode_first_stage(samples) + + waveform = self.mel_spectrogram_to_waveform(mel) + + if waveform.shape[0] > 1: + similarity = self.cond_stage_model.cos_similarity( + torch.FloatTensor(waveform).squeeze(1), text + ) + + best_index = [] + for i in range(z.shape[0]): + candidates = similarity[i :: z.shape[0]] + max_index = torch.argmax(candidates).item() + best_index.append(i + max_index * z.shape[0]) + + waveform = waveform[best_index] + # print("Similarity between generated audio and text", similarity) + # print("Choose the following indexes:", best_index) + + return waveform \ No newline at end of file diff --git a/audioldm/pipeline.py b/audioldm/pipeline.py new file mode 100755 index 0000000000000000000000000000000000000000..b08e1f77206483025ce027588c2dea1de78ae26c --- /dev/null +++ b/audioldm/pipeline.py @@ -0,0 +1,301 @@ +import os + +import argparse +import yaml +import torch +from torch import autocast +from tqdm import tqdm, trange + +from audioldm import LatentDiffusion, seed_everything +from audioldm.utils import default_audioldm_config, get_duration, get_bit_depth, get_metadata, download_checkpoint +from audioldm.audio import wav_to_fbank, TacotronSTFT, read_wav_file +from audioldm.latent_diffusion.ddim import DDIMSampler +from einops import repeat +import os + +def make_batch_for_text_to_audio(text, waveform=None, fbank=None, batchsize=1): + text = [text] * batchsize + if batchsize < 1: + print("Warning: Batchsize must be at least 1. Batchsize is set to .") + + if(fbank is None): + fbank = torch.zeros((batchsize, 1024, 64)) # Not used, here to keep the code format + else: + fbank = torch.FloatTensor(fbank) + fbank = fbank.expand(batchsize, 1024, 64) + assert fbank.size(0) == batchsize + + stft = torch.zeros((batchsize, 1024, 512)) # Not used + + if(waveform is None): + waveform = torch.zeros((batchsize, 160000)) # Not used + else: + waveform = torch.FloatTensor(waveform) + waveform = waveform.expand(batchsize, -1) + assert waveform.size(0) == batchsize + + fname = [""] * batchsize # Not used + + batch = ( + fbank, + stft, + None, + fname, + waveform, + text, + ) + return batch + +def round_up_duration(duration): + return int(round(duration/2.5) + 1) * 2.5 + +def build_model( + ckpt_path=None, + config=None, + model_name="audioldm-s-full" +): + print("Load AudioLDM: %s", model_name) + + if(ckpt_path is None): + ckpt_path = get_metadata()[model_name]["path"] + + if(not os.path.exists(ckpt_path)): + download_checkpoint(model_name) + + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + + if config is not None: + assert type(config) is str + config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) + else: + config = default_audioldm_config(model_name) + + # Use text as condition instead of using waveform during training + config["model"]["params"]["device"] = device + config["model"]["params"]["cond_stage_key"] = "text" + + # No normalization here + latent_diffusion = LatentDiffusion(**config["model"]["params"]) + + resume_from_checkpoint = ckpt_path + + checkpoint = torch.load(resume_from_checkpoint, map_location=device) + latent_diffusion.load_state_dict(checkpoint["state_dict"]) + + latent_diffusion.eval() + latent_diffusion = latent_diffusion.to(device) + + latent_diffusion.cond_stage_model.embed_mode = "text" + return latent_diffusion + +def duration_to_latent_t_size(duration): + return int(duration * 25.6) + +def set_cond_audio(latent_diffusion): + latent_diffusion.cond_stage_key = "waveform" + latent_diffusion.cond_stage_model.embed_mode="audio" + return latent_diffusion + +def set_cond_text(latent_diffusion): + latent_diffusion.cond_stage_key = "text" + latent_diffusion.cond_stage_model.embed_mode="text" + return latent_diffusion + +def text_to_audio( + latent_diffusion, + text, + original_audio_file_path = None, + seed=42, + ddim_steps=200, + duration=10, + batchsize=1, + guidance_scale=2.5, + n_candidate_gen_per_text=3, + config=None, +): + seed_everything(int(seed)) + waveform = None + if(original_audio_file_path is not None): + waveform = read_wav_file(original_audio_file_path, int(duration * 102.4) * 160) + + batch = make_batch_for_text_to_audio(text, waveform=waveform, batchsize=batchsize) + + latent_diffusion.latent_t_size = duration_to_latent_t_size(duration) + + if(waveform is not None): + print("Generate audio that has similar content as %s" % original_audio_file_path) + latent_diffusion = set_cond_audio(latent_diffusion) + else: + print("Generate audio using text %s" % text) + latent_diffusion = set_cond_text(latent_diffusion) + + with torch.no_grad(): + waveform = latent_diffusion.generate_sample( + [batch], + unconditional_guidance_scale=guidance_scale, + ddim_steps=ddim_steps, + n_candidate_gen_per_text=n_candidate_gen_per_text, + duration=duration, + ) + return waveform + +def style_transfer( + latent_diffusion, + text, + original_audio_file_path, + transfer_strength, + seed=42, + duration=10, + batchsize=1, + guidance_scale=2.5, + ddim_steps=200, + config=None, +): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + + assert original_audio_file_path is not None, "You need to provide the original audio file path" + + audio_file_duration = get_duration(original_audio_file_path) + + assert get_bit_depth(original_audio_file_path) == 16, "The bit depth of the original audio file %s must be 16" % original_audio_file_path + + # if(duration > 20): + # print("Warning: The duration of the audio file %s must be less than 20 seconds. Longer duration will result in Nan in model output (we are still debugging that); Automatically set duration to 20 seconds") + # duration = 20 + + if(duration >= audio_file_duration): + print("Warning: Duration you specified %s-seconds must equal or smaller than the audio file duration %ss" % (duration, audio_file_duration)) + duration = round_up_duration(audio_file_duration) + print("Set new duration as %s-seconds" % duration) + + # duration = round_up_duration(duration) + + latent_diffusion = set_cond_text(latent_diffusion) + + if config is not None: + assert type(config) is str + config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) + else: + config = default_audioldm_config() + + seed_everything(int(seed)) + # latent_diffusion.latent_t_size = duration_to_latent_t_size(duration) + latent_diffusion.cond_stage_model.embed_mode = "text" + + fn_STFT = TacotronSTFT( + config["preprocessing"]["stft"]["filter_length"], + config["preprocessing"]["stft"]["hop_length"], + config["preprocessing"]["stft"]["win_length"], + config["preprocessing"]["mel"]["n_mel_channels"], + config["preprocessing"]["audio"]["sampling_rate"], + config["preprocessing"]["mel"]["mel_fmin"], + config["preprocessing"]["mel"]["mel_fmax"], + ) + + mel, _, _ = wav_to_fbank( + original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT + ) + mel = mel.unsqueeze(0).unsqueeze(0).to(device) + mel = repeat(mel, "1 ... -> b ...", b=batchsize) + init_latent = latent_diffusion.get_first_stage_encoding( + latent_diffusion.encode_first_stage(mel) + ) # move to latent space, encode and sample + if(torch.max(torch.abs(init_latent)) > 1e2): + init_latent = torch.clip(init_latent, min=-10, max=10) + sampler = DDIMSampler(latent_diffusion) + sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=1.0, verbose=False) + + t_enc = int(transfer_strength * ddim_steps) + prompts = text + + with torch.no_grad(): + with autocast("cuda"): + with latent_diffusion.ema_scope(): + uc = None + if guidance_scale != 1.0: + uc = latent_diffusion.cond_stage_model.get_unconditional_condition( + batchsize + ) + + c = latent_diffusion.get_learned_conditioning([prompts] * batchsize) + z_enc = sampler.stochastic_encode( + init_latent, torch.tensor([t_enc] * batchsize).to(device) + ) + samples = sampler.decode( + z_enc, + c, + t_enc, + unconditional_guidance_scale=guidance_scale, + unconditional_conditioning=uc, + ) + # x_samples = latent_diffusion.decode_first_stage(samples) # Will result in Nan in output + # print(torch.sum(torch.isnan(samples))) + x_samples = latent_diffusion.decode_first_stage(samples) + # print(x_samples) + x_samples = latent_diffusion.decode_first_stage(samples[:,:,:-3,:]) + # print(x_samples) + waveform = latent_diffusion.first_stage_model.decode_to_waveform( + x_samples + ) + + return waveform + +def super_resolution_and_inpainting( + latent_diffusion, + text, + original_audio_file_path = None, + seed=42, + ddim_steps=200, + duration=None, + batchsize=1, + guidance_scale=2.5, + n_candidate_gen_per_text=3, + time_mask_ratio_start_and_end=(0.10, 0.15), # regenerate the 10% to 15% of the time steps in the spectrogram + # time_mask_ratio_start_and_end=(1.0, 1.0), # no inpainting + # freq_mask_ratio_start_and_end=(0.75, 1.0), # regenerate the higher 75% to 100% mel bins + freq_mask_ratio_start_and_end=(1.0, 1.0), # no super-resolution + config=None, +): + seed_everything(int(seed)) + if config is not None: + assert type(config) is str + config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) + else: + config = default_audioldm_config() + fn_STFT = TacotronSTFT( + config["preprocessing"]["stft"]["filter_length"], + config["preprocessing"]["stft"]["hop_length"], + config["preprocessing"]["stft"]["win_length"], + config["preprocessing"]["mel"]["n_mel_channels"], + config["preprocessing"]["audio"]["sampling_rate"], + config["preprocessing"]["mel"]["mel_fmin"], + config["preprocessing"]["mel"]["mel_fmax"], + ) + + # waveform = read_wav_file(original_audio_file_path, None) + mel, _, _ = wav_to_fbank( + original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT + ) + + batch = make_batch_for_text_to_audio(text, fbank=mel[None,...], batchsize=batchsize) + + # latent_diffusion.latent_t_size = duration_to_latent_t_size(duration) + latent_diffusion = set_cond_text(latent_diffusion) + + with torch.no_grad(): + waveform = latent_diffusion.generate_sample_masked( + [batch], + unconditional_guidance_scale=guidance_scale, + ddim_steps=ddim_steps, + n_candidate_gen_per_text=n_candidate_gen_per_text, + duration=duration, + time_mask_ratio_start_and_end=time_mask_ratio_start_and_end, + freq_mask_ratio_start_and_end=freq_mask_ratio_start_and_end + ) + return waveform \ No newline at end of file diff --git a/audioldm/utils.py b/audioldm/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..5401b29d4366774233f1bf4a9e7fcb7ce214187e --- /dev/null +++ b/audioldm/utils.py @@ -0,0 +1,281 @@ +import contextlib +import importlib + +from inspect import isfunction +import os +import soundfile as sf +import time +import wave + +import urllib.request +import progressbar + +CACHE_DIR = os.getenv( + "AUDIOLDM_CACHE_DIR", + os.path.join(os.path.expanduser("~"), ".cache/audioldm")) + +def get_duration(fname): + with contextlib.closing(wave.open(fname, 'r')) as f: + frames = f.getnframes() + rate = f.getframerate() + return frames / float(rate) + +def get_bit_depth(fname): + with contextlib.closing(wave.open(fname, 'r')) as f: + bit_depth = f.getsampwidth() * 8 + return bit_depth + +def get_time(): + t = time.localtime() + return time.strftime("%d_%m_%Y_%H_%M_%S", t) + +def seed_everything(seed): + import random, os + import numpy as np + import torch + + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + + +def save_wave(waveform, savepath, name="outwav"): + if type(name) is not list: + name = [name] * waveform.shape[0] + + for i in range(waveform.shape[0]): + path = os.path.join( + savepath, + "%s_%s.wav" + % ( + os.path.basename(name[i]) + if (not ".wav" in name[i]) + else os.path.basename(name[i]).split(".")[0], + i, + ), + ) + print("Save audio to %s" % path) + sf.write(path, waveform[i, 0], samplerate=16000) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def default_audioldm_config(model_name="audioldm-s-full"): + basic_config = { + "wave_file_save_path": "./output", + "id": { + "version": "v1", + "name": "default", + "root": "/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/AudioLDM-python/config/default/latent_diffusion.yaml", + }, + "preprocessing": { + "audio": {"sampling_rate": 16000, "max_wav_value": 32768}, + "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024}, + "mel": { + "n_mel_channels": 64, + "mel_fmin": 0, + "mel_fmax": 8000, + "freqm": 0, + "timem": 0, + "blur": False, + "mean": -4.63, + "std": 2.74, + "target_length": 1024, + }, + }, + "model": { + "device": "cuda", + "target": "audioldm.pipline.LatentDiffusion", + "params": { + "base_learning_rate": 5e-06, + "linear_start": 0.0015, + "linear_end": 0.0195, + "num_timesteps_cond": 1, + "log_every_t": 200, + "timesteps": 1000, + "first_stage_key": "fbank", + "cond_stage_key": "waveform", + "latent_t_size": 256, + "latent_f_size": 16, + "channels": 8, + "cond_stage_trainable": True, + "conditioning_key": "film", + "monitor": "val/loss_simple_ema", + "scale_by_std": True, + "unet_config": { + "target": "audioldm.latent_diffusion.openaimodel.UNetModel", + "params": { + "image_size": 64, + "extra_film_condition_dim": 512, + "extra_film_use_concat": True, + "in_channels": 8, + "out_channels": 8, + "model_channels": 128, + "attention_resolutions": [8, 4, 2], + "num_res_blocks": 2, + "channel_mult": [1, 2, 3, 5], + "num_head_channels": 32, + "use_spatial_transformer": True, + }, + }, + "first_stage_config": { + "base_learning_rate": 4.5e-05, + "target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL", + "params": { + "monitor": "val/rec_loss", + "image_key": "fbank", + "subband": 1, + "embed_dim": 8, + "time_shuffle": 1, + "ddconfig": { + "double_z": True, + "z_channels": 8, + "resolution": 256, + "downsample_time": False, + "in_channels": 1, + "out_ch": 1, + "ch": 128, + "ch_mult": [1, 2, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, + }, + }, + }, + "cond_stage_config": { + "target": "audioldm.clap.encoders.CLAPAudioEmbeddingClassifierFreev2", + "params": { + "key": "waveform", + "sampling_rate": 16000, + "embed_mode": "audio", + "unconditional_prob": 0.1, + }, + }, + }, + }, + } + + if("-l-" in model_name): + basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 256 + basic_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = 64 + elif("-m-" in model_name): + basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 192 + basic_config["model"]["params"]["cond_stage_config"]["params"]["amodel"] = "HTSAT-base" # This model use a larger HTAST + + return basic_config + +def get_metadata(): + return { + "audioldm-s-full": { + "path": os.path.join( + CACHE_DIR, + "audioldm-s-full.ckpt", + ), + "url": "https://zenodo.org/record/7600541/files/audioldm-s-full?download=1", + }, + "audioldm-l-full": { + "path": os.path.join( + CACHE_DIR, + "audioldm-l-full.ckpt", + ), + "url": "https://zenodo.org/record/7698295/files/audioldm-full-l.ckpt?download=1", + }, + "audioldm-s-full-v2": { + "path": os.path.join( + CACHE_DIR, + "audioldm-s-full-v2.ckpt", + ), + "url": "https://zenodo.org/record/7698295/files/audioldm-full-s-v2.ckpt?download=1", + }, + "audioldm-m-text-ft": { + "path": os.path.join( + CACHE_DIR, + "audioldm-m-text-ft.ckpt", + ), + "url": "https://zenodo.org/record/7813012/files/audioldm-m-text-ft.ckpt?download=1", + }, + "audioldm-s-text-ft": { + "path": os.path.join( + CACHE_DIR, + "audioldm-s-text-ft.ckpt", + ), + "url": "https://zenodo.org/record/7813012/files/audioldm-s-text-ft.ckpt?download=1", + }, + "audioldm-m-full": { + "path": os.path.join( + CACHE_DIR, + "audioldm-m-full.ckpt", + ), + "url": "https://zenodo.org/record/7813012/files/audioldm-m-full.ckpt?download=1", + }, + } + +class MyProgressBar(): + def __init__(self): + self.pbar = None + + def __call__(self, block_num, block_size, total_size): + if not self.pbar: + self.pbar=progressbar.ProgressBar(maxval=total_size) + self.pbar.start() + + downloaded = block_num * block_size + if downloaded < total_size: + self.pbar.update(downloaded) + else: + self.pbar.finish() + +def download_checkpoint(checkpoint_name="audioldm-s-full"): + meta = get_metadata() + if(checkpoint_name not in meta.keys()): + print("The model name you provided is not supported. Please use one of the following: ", meta.keys()) + + if not os.path.exists(meta[checkpoint_name]["path"]) or os.path.getsize(meta[checkpoint_name]["path"]) < 2*10**9: + os.makedirs(os.path.dirname(meta[checkpoint_name]["path"]), exist_ok=True) + print(f"Downloading the main structure of {checkpoint_name} into {os.path.dirname(meta[checkpoint_name]['path'])}") + + urllib.request.urlretrieve(meta[checkpoint_name]["url"], meta[checkpoint_name]["path"], MyProgressBar()) + print( + "Weights downloaded in: {} Size: {}".format( + meta[checkpoint_name]["path"], + os.path.getsize(meta[checkpoint_name]["path"]), + ) + ) + \ No newline at end of file diff --git a/audioldm/variational_autoencoder/__init__.py b/audioldm/variational_autoencoder/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm/variational_autoencoder/autoencoder.py b/audioldm/variational_autoencoder/autoencoder.py new file mode 100755 index 0000000000000000000000000000000000000000..9cf409edff3d6440e076f169cac10f6d0c157a95 --- /dev/null +++ b/audioldm/variational_autoencoder/autoencoder.py @@ -0,0 +1,103 @@ +import torch +from audioldm.latent_diffusion.ema import * +from audioldm.variational_autoencoder.modules import Encoder, Decoder +from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution + +from audioldm.hifigan.utilities import get_vocoder, vocoder_infer + + +class AutoencoderKL(nn.Module): + def __init__( + self, + ddconfig=None, + lossconfig=None, + image_key="fbank", + embed_dim=None, + time_shuffle=1, + subband=1, + ckpt_path=None, + reload_from_ckpt=None, + ignore_keys=[], + colorize_nlabels=None, + monitor=None, + base_learning_rate=1e-5, + ): + super().__init__() + + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + + self.subband = int(subband) + + if self.subband > 1: + print("Use subband decomposition %s" % self.subband) + + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + + self.vocoder = get_vocoder(None, "cpu") + self.embed_dim = embed_dim + + if monitor is not None: + self.monitor = monitor + + self.time_shuffle = time_shuffle + self.reload_from_ckpt = reload_from_ckpt + self.reloaded = False + self.mean, self.std = None, None + + def encode(self, x): + # x = self.time_shuffle_operation(x) + x = self.freq_split_subband(x) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + dec = self.freq_merge_subband(dec) + return dec + + def decode_to_waveform(self, dec): + dec = dec.squeeze(1).permute(0, 2, 1) + wav_reconstruction = vocoder_infer(dec, self.vocoder) + return wav_reconstruction + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + + if self.flag_first_run: + print("Latent size: ", z.size()) + self.flag_first_run = False + + dec = self.decode(z) + + return dec, posterior + + def freq_split_subband(self, fbank): + if self.subband == 1 or self.image_key != "stft": + return fbank + + bs, ch, tstep, fbins = fbank.size() + + assert fbank.size(-1) % self.subband == 0 + assert ch == 1 + + return ( + fbank.squeeze(1) + .reshape(bs, tstep, self.subband, fbins // self.subband) + .permute(0, 2, 1, 3) + ) + + def freq_merge_subband(self, subband_fbank): + if self.subband == 1 or self.image_key != "stft": + return subband_fbank + assert subband_fbank.size(1) == self.subband # Channel dimension + bs, sub_ch, tstep, fbins = subband_fbank.size() + return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1) diff --git a/audioldm/variational_autoencoder/distributions.py b/audioldm/variational_autoencoder/distributions.py new file mode 100755 index 0000000000000000000000000000000000000000..58eb535e7769f402169ddff77ee45c96ba3650d9 --- /dev/null +++ b/audioldm/variational_autoencoder/distributions.py @@ -0,0 +1,102 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.mean( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/audioldm/variational_autoencoder/modules.py b/audioldm/variational_autoencoder/modules.py new file mode 100755 index 0000000000000000000000000000000000000000..e48386d045c1d0e159de33db02af1035159c3447 --- /dev/null +++ b/audioldm/variational_autoencoder/modules.py @@ -0,0 +1,1066 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from audioldm.utils import instantiate_from_config +from audioldm.latent_diffusion.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class UpsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=1, padding=2 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class DownsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2)) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_ + ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" + # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb + ) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + downsample_time_stride4_levels=[], + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.downsample_time_stride4_levels = downsample_time_stride4_levels + + if len(self.downsample_time_stride4_levels) > 0: + assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( + "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" + % str(self.num_resolutions) + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + if i_level in self.downsample_time_stride4_levels: + down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv) + else: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + downsample_time_stride4_levels=[], + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + self.downsample_time_stride4_levels = downsample_time_stride4_levels + + if len(self.downsample_time_stride4_levels) > 0: + assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( + "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" + % str(self.num_resolutions) + ) + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + # print("Working with z of shape {} = {} dimensions.".format( + # self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level - 1 in self.downsample_time_stride4_levels: + up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv) + else: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock( + in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0, + ): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d( + in_channels, mid_channels, kernel_size=3, stride=1, padding=1 + ) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate( + x, + size=( + int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)), + ), + ) + x = self.attn(x).contiguous() + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1.0 + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler( + factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, + out_channels=in_channels, + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print( + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" + ) + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=4, stride=2, padding=1 + ) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate( + x, mode=self.mode, align_corners=False, scale_factor=scale_factor + ) + return x + + +class FirstStagePostProcessor(nn.Module): + def __init__( + self, + ch_mult: list, + in_channels, + pretrained_model: nn.Module = None, + reshape=False, + n_channels=None, + dropout=0.0, + pretrained_config=None, + ): + super().__init__() + if pretrained_config is None: + assert ( + pretrained_model is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert ( + pretrained_config is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) + self.proj = nn.Conv2d( + in_channels, n_channels, kernel_size=3, stride=1, padding=1 + ) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append( + ResnetBlock( + in_channels=ch_in, out_channels=m * n_channels, dropout=dropout + ) + ) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def encode_with_pretrained(self, x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self, x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model, self.downsampler): + z = submodel(z, temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z, "b c h w -> b (h w) c") + return z diff --git a/inversion_utils.py b/inversion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d2606a08af913705217165b4846205696575ad18 --- /dev/null +++ b/inversion_utils.py @@ -0,0 +1,450 @@ +import torch +from tqdm import tqdm +# from torchvision import transforms as T +from typing import List, Optional, Dict, Union +from models import PipelineWrapper + + +def mu_tilde(model, xt, x0, timestep): + "mu_tilde(x_t, x_0) DDPM paper eq. 7" + prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps + alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 \ + else model.scheduler.final_alpha_cumprod + alpha_t = model.scheduler.alphas[timestep] + beta_t = 1 - alpha_t + alpha_bar = model.scheduler.alphas_cumprod[timestep] + return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1-alpha_bar)) * x0 + \ + ((alpha_t**0.5 * (1-alpha_prod_t_prev)) / (1 - alpha_bar)) * xt + + +def sample_xts_from_x0(model, x0, num_inference_steps=50, x_prev_mode=False): + """ + Samples from P(x_1:T|x_0) + """ + # torch.manual_seed(43256465436) + alpha_bar = model.model.scheduler.alphas_cumprod + sqrt_one_minus_alpha_bar = (1-alpha_bar) ** 0.5 + alphas = model.model.scheduler.alphas + # betas = 1 - alphas + variance_noise_shape = ( + num_inference_steps + 1, + model.model.unet.config.in_channels, + # model.unet.sample_size, + # model.unet.sample_size) + x0.shape[-2], + x0.shape[-1]) + + timesteps = model.model.scheduler.timesteps.to(model.device) + t_to_idx = {int(v): k for k, v in enumerate(timesteps)} + xts = torch.zeros(variance_noise_shape).to(x0.device) + xts[0] = x0 + x_prev = x0 + for t in reversed(timesteps): + # idx = t_to_idx[int(t)] + idx = num_inference_steps-t_to_idx[int(t)] + if x_prev_mode: + xts[idx] = x_prev * (alphas[t] ** 0.5) + torch.randn_like(x0) * ((1-alphas[t]) ** 0.5) + x_prev = xts[idx].clone() + else: + xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t] + # xts = torch.cat([xts, x0 ],dim = 0) + + return xts + + +def forward_step(model, model_output, timestep, sample): + next_timestep = min(model.scheduler.config.num_train_timesteps - 2, + timestep + model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps) + + # 2. compute alphas, betas + alpha_prod_t = model.scheduler.alphas_cumprod[timestep] + # alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] if next_ltimestep >= 0 \ + # else self.scheduler.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + + # 5. TODO: simple noising implementatiom + next_sample = model.scheduler.add_noise(pred_original_sample, model_output, torch.LongTensor([next_timestep])) + return next_sample + + +def inversion_forward_process(model: PipelineWrapper, + x0: torch.Tensor, + etas: Optional[float] = None, + prog_bar: bool = False, + prompts: List[str] = [""], + cfg_scales: List[float] = [3.5], + num_inference_steps: int = 50, + eps: Optional[float] = None, + cutoff_points: Optional[List[float]] = None, + numerical_fix: bool = False, + extract_h_space: bool = False, + extract_skipconns: bool = False, + x_prev_mode: bool = False): + if len(prompts) > 1 and extract_h_space: + raise NotImplementedError("How do you split cfg_scales for hspace? TODO") + + if len(prompts) > 1 or prompts[0] != "": + text_embeddings_hidden_states, text_embeddings_class_labels, \ + text_embeddings_boolean_prompt_mask = model.encode_text(prompts) + # text_embeddings = encode_text(model, prompt) + + # # classifier free guidance + batch_size = len(prompts) + cfg_scales_tensor = torch.ones((batch_size, *x0.shape[1:]), device=model.device, dtype=x0.dtype) + + # if len(prompts) > 1: + # if cutoff_points is None: + # cutoff_points = [i * 1 / batch_size for i in range(1, batch_size)] + # if len(cfg_scales) == 1: + # cfg_scales *= batch_size + # elif len(cfg_scales) < batch_size: + # raise ValueError("Not enough target CFG scales") + + # cutoff_points = [int(x * cfg_scales_tensor.shape[2]) for x in cutoff_points] + # cutoff_points = [0, *cutoff_points, cfg_scales_tensor.shape[2]] + + # for i, (start, end) in enumerate(zip(cutoff_points[:-1], cutoff_points[1:])): + # cfg_scales_tensor[i, :, end:] = 0 + # cfg_scales_tensor[i, :, :start] = 0 + # cfg_scales_tensor[i] *= cfg_scales[i] + # if prompts[i] == "": + # cfg_scales_tensor[i] = 0 + # cfg_scales_tensor = T.functional.gaussian_blur(cfg_scales_tensor, kernel_size=15, sigma=1) + # else: + cfg_scales_tensor *= cfg_scales[0] + + uncond_embedding_hidden_states, uncond_embedding_class_lables, uncond_boolean_prompt_mask = model.encode_text([""]) + # uncond_embedding = encode_text(model, "") + timesteps = model.model.scheduler.timesteps.to(model.device) + variance_noise_shape = ( + num_inference_steps, + model.model.unet.config.in_channels, + # model.unet.sample_size, + # model.unet.sample_size) + x0.shape[-2], + x0.shape[-1]) + + if etas is None or (type(etas) in [int, float] and etas == 0): + eta_is_zero = True + zs = None + else: + eta_is_zero = False + if type(etas) in [int, float]: + etas = [etas]*model.model.scheduler.num_inference_steps + xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps, x_prev_mode=x_prev_mode) + alpha_bar = model.model.scheduler.alphas_cumprod + zs = torch.zeros(size=variance_noise_shape, device=model.device) + hspaces = [] + skipconns = [] + t_to_idx = {int(v): k for k, v in enumerate(timesteps)} + xt = x0 + # op = tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps) + op = tqdm(timesteps) if prog_bar else timesteps + + for t in op: + # idx = t_to_idx[int(t)] + idx = num_inference_steps - t_to_idx[int(t)] - 1 + # 1. predict noise residual + if not eta_is_zero: + xt = xts[idx+1][None] + + with torch.no_grad(): + out, out_hspace, out_skipconns = model.unet_forward(xt, timestep=t, + encoder_hidden_states=uncond_embedding_hidden_states, + class_labels=uncond_embedding_class_lables, + encoder_attention_mask=uncond_boolean_prompt_mask) + # out = model.unet.forward(xt, timestep= t, encoder_hidden_states=uncond_embedding) + if len(prompts) > 1 or prompts[0] != "": + cond_out, cond_out_hspace, cond_out_skipconns = model.unet_forward( + xt.expand(len(prompts), -1, -1, -1), timestep=t, + encoder_hidden_states=text_embeddings_hidden_states, + class_labels=text_embeddings_class_labels, + encoder_attention_mask=text_embeddings_boolean_prompt_mask) + # cond_out = model.unet.forward(xt, timestep=t, encoder_hidden_states = text_embeddings) + + if len(prompts) > 1 or prompts[0] != "": + # # classifier free guidance + noise_pred = out.sample + \ + (cfg_scales_tensor * (cond_out.sample - out.sample.expand(batch_size, -1, -1, -1)) + ).sum(axis=0).unsqueeze(0) + if extract_h_space or extract_skipconns: + noise_h_space = out_hspace + cfg_scales[0] * (cond_out_hspace - out_hspace) + if extract_skipconns: + noise_skipconns = {k: [out_skipconns[k][j] + cfg_scales[0] * + (cond_out_skipconns[k][j] - out_skipconns[k][j]) + for j in range(len(out_skipconns[k]))] + for k in out_skipconns} + else: + noise_pred = out.sample + if extract_h_space or extract_skipconns: + noise_h_space = out_hspace + if extract_skipconns: + noise_skipconns = out_skipconns + if extract_h_space or extract_skipconns: + hspaces.append(noise_h_space) + if extract_skipconns: + skipconns.append(noise_skipconns) + + if eta_is_zero: + # 2. compute more noisy image and set x_t -> x_t+1 + xt = forward_step(model.model, noise_pred, t, xt) + else: + # xtm1 = xts[idx+1][None] + xtm1 = xts[idx][None] + # pred of x0 + if model.model.scheduler.config.prediction_type == 'epsilon': + pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5 + elif model.model.scheduler.config.prediction_type == 'v_prediction': + pred_original_sample = (alpha_bar[t] ** 0.5) * xt - ((1 - alpha_bar[t]) ** 0.5) * noise_pred + + # direction to xt + prev_timestep = t - model.model.scheduler.config.num_train_timesteps // \ + model.model.scheduler.num_inference_steps + + alpha_prod_t_prev = model.get_alpha_prod_t_prev(prev_timestep) + variance = model.get_variance(t, prev_timestep) + + if model.model.scheduler.config.prediction_type == 'epsilon': + radom_noise_pred = noise_pred + elif model.model.scheduler.config.prediction_type == 'v_prediction': + radom_noise_pred = (alpha_bar[t] ** 0.5) * noise_pred + ((1 - alpha_bar[t]) ** 0.5) * xt + + pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * radom_noise_pred + + mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5) + + zs[idx] = z + + # correction to avoid error accumulation + if numerical_fix: + xtm1 = mu_xt + (etas[idx] * variance ** 0.5)*z + xts[idx] = xtm1 + + if zs is not None: + # zs[-1] = torch.zeros_like(zs[-1]) + zs[0] = torch.zeros_like(zs[0]) + # zs_cycle[0] = torch.zeros_like(zs[0]) + + if extract_h_space: + hspaces = torch.concat(hspaces, axis=0) + return xt, zs, xts, hspaces + + if extract_skipconns: + hspaces = torch.concat(hspaces, axis=0) + return xt, zs, xts, hspaces, skipconns + + return xt, zs, xts + + +def reverse_step(model, model_output, timestep, sample, eta=0, variance_noise=None): + # 1. get previous step value (=t-1) + prev_timestep = timestep - model.model.scheduler.config.num_train_timesteps // \ + model.model.scheduler.num_inference_steps + # 2. compute alphas, betas + alpha_prod_t = model.model.scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = model.get_alpha_prod_t_prev(prev_timestep) + beta_prod_t = 1 - alpha_prod_t + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if model.model.scheduler.config.prediction_type == 'epsilon': + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif model.model.scheduler.config.prediction_type == 'v_prediction': + pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + # variance = self.scheduler._get_variance(timestep, prev_timestep) + variance = model.get_variance(timestep, prev_timestep) + # std_dev_t = eta * variance ** (0.5) + # Take care of asymetric reverse process (asyrp) + if model.model.scheduler.config.prediction_type == 'epsilon': + model_output_direction = model_output + elif model.model.scheduler.config.prediction_type == 'v_prediction': + model_output_direction = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction + pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + # 8. Add noice if eta > 0 + if eta > 0: + if variance_noise is None: + variance_noise = torch.randn(model_output.shape, device=model.device) + sigma_z = eta * variance ** (0.5) * variance_noise + prev_sample = prev_sample + sigma_z + + return prev_sample + + +def inversion_reverse_process(model: PipelineWrapper, + xT: torch.Tensor, + skips: torch.Tensor, + fix_alpha: float = 0.1, + etas: float = 0, + prompts: List[str] = [""], + neg_prompts: List[str] = [""], + cfg_scales: Optional[List[float]] = None, + prog_bar: bool = False, + zs: Optional[List[torch.Tensor]] = None, + # controller=None, + cutoff_points: Optional[List[float]] = None, + hspace_add: Optional[torch.Tensor] = None, + hspace_replace: Optional[torch.Tensor] = None, + skipconns_replace: Optional[Dict[int, torch.Tensor]] = None, + zero_out_resconns: Optional[Union[int, List]] = None, + asyrp: bool = False, + extract_h_space: bool = False, + extract_skipconns: bool = False): + + batch_size = len(prompts) + + text_embeddings_hidden_states, text_embeddings_class_labels, \ + text_embeddings_boolean_prompt_mask = model.encode_text(prompts) + uncond_embedding_hidden_states, uncond_embedding_class_lables, \ + uncond_boolean_prompt_mask = model.encode_text(neg_prompts) + # text_embeddings = encode_text(model, prompts) + # uncond_embedding = encode_text(model, [""] * batch_size) + + masks = torch.ones((batch_size, *xT.shape[1:]), device=model.device, dtype=xT.dtype) + cfg_scales_tensor = torch.ones((batch_size, *xT.shape[1:]), device=model.device, dtype=xT.dtype) + + # if batch_size > 1: + # if cutoff_points is None: + # cutoff_points = [i * 1 / batch_size for i in range(1, batch_size)] + # if len(cfg_scales) == 1: + # cfg_scales *= batch_size + # elif len(cfg_scales) < batch_size: + # raise ValueError("Not enough target CFG scales") + + # cutoff_points = [int(x * cfg_scales_tensor.shape[2]) for x in cutoff_points] + # cutoff_points = [0, *cutoff_points, cfg_scales_tensor.shape[2]] + + # for i, (start, end) in enumerate(zip(cutoff_points[:-1], cutoff_points[1:])): + # cfg_scales_tensor[i, :, end:] = 0 + # cfg_scales_tensor[i, :, :start] = 0 + # masks[i, :, end:] = 0 + # masks[i, :, :start] = 0 + # cfg_scales_tensor[i] *= cfg_scales[i] + # cfg_scales_tensor = T.functional.gaussian_blur(cfg_scales_tensor, kernel_size=15, sigma=1) + # masks = T.functional.gaussian_blur(masks, kernel_size=15, sigma=1) + # else: + cfg_scales_tensor *= cfg_scales[0] + + if etas is None: + etas = 0 + if type(etas) in [int, float]: + etas = [etas]*model.model.scheduler.num_inference_steps + assert len(etas) == model.model.scheduler.num_inference_steps + timesteps = model.model.scheduler.timesteps.to(model.device) + + # xt = xT.expand(1, -1, -1, -1) + xt = xT[skips.max()].unsqueeze(0) + op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:] + + t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])} + hspaces = [] + skipconns = [] + + for it, t in enumerate(op): + # idx = t_to_idx[int(t)] + idx = model.model.scheduler.num_inference_steps - t_to_idx[int(t)] - \ + (model.model.scheduler.num_inference_steps - zs.shape[0] + 1) + # # Unconditional embedding + with torch.no_grad(): + uncond_out, out_hspace, out_skipconns = model.unet_forward( + xt, timestep=t, + encoder_hidden_states=uncond_embedding_hidden_states, + class_labels=uncond_embedding_class_lables, + encoder_attention_mask=uncond_boolean_prompt_mask, + mid_block_additional_residual=(None if hspace_add is None else + (1 / (cfg_scales[0] + 1)) * + (hspace_add[-zs.shape[0]:][it] if hspace_add.shape[0] > 1 + else hspace_add)), + replace_h_space=(None if hspace_replace is None else + (hspace_replace[-zs.shape[0]:][it].unsqueeze(0) if hspace_replace.shape[0] > 1 + else hspace_replace)), + zero_out_resconns=zero_out_resconns, + replace_skip_conns=(None if skipconns_replace is None else + (skipconns_replace[-zs.shape[0]:][it] if len(skipconns_replace) > 1 + else skipconns_replace)) + ) # encoder_hidden_states = uncond_embedding) + + # # Conditional embedding + if prompts: + with torch.no_grad(): + cond_out, cond_out_hspace, cond_out_skipconns = model.unet_forward( + xt.expand(batch_size, -1, -1, -1), + timestep=t, + encoder_hidden_states=text_embeddings_hidden_states, + class_labels=text_embeddings_class_labels, + encoder_attention_mask=text_embeddings_boolean_prompt_mask, + mid_block_additional_residual=(None if hspace_add is None else + (cfg_scales[0] / (cfg_scales[0] + 1)) * + (hspace_add[-zs.shape[0]:][it] if hspace_add.shape[0] > 1 + else hspace_add)), + replace_h_space=(None if hspace_replace is None else + (hspace_replace[-zs.shape[0]:][it].unsqueeze(0) if hspace_replace.shape[0] > 1 + else hspace_replace)), + zero_out_resconns=zero_out_resconns, + replace_skip_conns=(None if skipconns_replace is None else + (skipconns_replace[-zs.shape[0]:][it] if len(skipconns_replace) > 1 + else skipconns_replace)) + ) # encoder_hidden_states = text_embeddings) + + z = zs[idx] if zs is not None else None + # print(f'idx: {idx}') + # print(f't: {t}') + z = z.unsqueeze(0) + # z = z.expand(batch_size, -1, -1, -1) + if prompts: + # # classifier free guidance + # noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample) + noise_pred = uncond_out.sample + \ + (cfg_scales_tensor * (cond_out.sample - uncond_out.sample.expand(batch_size, -1, -1, -1)) + ).sum(axis=0).unsqueeze(0) + if extract_h_space or extract_skipconns: + noise_h_space = out_hspace + cfg_scales[0] * (cond_out_hspace - out_hspace) + if extract_skipconns: + noise_skipconns = {k: [out_skipconns[k][j] + cfg_scales[0] * + (cond_out_skipconns[k][j] - out_skipconns[k][j]) + for j in range(len(out_skipconns[k]))] + for k in out_skipconns} + else: + noise_pred = uncond_out.sample + if extract_h_space or extract_skipconns: + noise_h_space = out_hspace + if extract_skipconns: + noise_skipconns = out_skipconns + + if extract_h_space or extract_skipconns: + hspaces.append(noise_h_space) + if extract_skipconns: + skipconns.append(noise_skipconns) + + # 2. compute less noisy image and set x_t -> x_t-1 + xt = reverse_step(model, noise_pred, t, xt, eta=etas[idx], variance_noise=z) + # if controller is not None: + # xt = controller.step_callback(xt) + + # "fix" xt + apply_fix = ((skips.max() - skips) > it) + if apply_fix.any(): + apply_fix = (apply_fix * fix_alpha).unsqueeze(1).unsqueeze(2).unsqueeze(3).to(xT.device) + xt = (masks * (xt.expand(batch_size, -1, -1, -1) * (1 - apply_fix) + + apply_fix * xT[skips.max() - it - 1].expand(batch_size, -1, -1, -1)) + ).sum(axis=0).unsqueeze(0) + + if extract_h_space: + return xt, zs, torch.concat(hspaces, axis=0) + + if extract_skipconns: + return xt, zs, torch.concat(hspaces, axis=0), skipconns + + return xt, zs diff --git a/models.py b/models.py new file mode 100644 index 0000000000000000000000000000000000000000..574ad9cb8f5b2c07b6edebabd1b62ac61344343c --- /dev/null +++ b/models.py @@ -0,0 +1,644 @@ +import torch +from diffusers import DDIMScheduler +from diffusers import AudioLDM2Pipeline +from transformers import RobertaTokenizer, RobertaTokenizerFast +from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput +from typing import Any, Dict, List, Optional, Tuple, Union + + +class PipelineWrapper(torch.nn.Module): + def __init__(self, model_id, device, double_precision=False, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.model_id = model_id + self.device = device + self.double_precision = double_precision + + def get_sigma(self, timestep) -> float: + sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.model.scheduler.alphas_cumprod - 1) + return sqrt_recipm1_alphas_cumprod[timestep] + + def load_scheduler(self): + pass + + def get_fn_STFT(self): + pass + + def vae_encode(self, x: torch.Tensor): + pass + + def vae_decode(self, x: torch.Tensor): + pass + + def decode_to_mel(self, x: torch.Tensor): + pass + + def encode_text(self, prompts: List[str]) -> Tuple: + pass + + def get_variance(self, timestep, prev_timestep): + pass + + def get_alpha_prod_t_prev(self, prev_timestep): + pass + + def unet_forward(self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + replace_h_space: Optional[torch.Tensor] = None, + replace_skip_conns: Optional[Dict[int, torch.Tensor]] = None, + return_dict: bool = True, + zero_out_resconns: Optional[Union[int, List]] = None) -> Tuple: + + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.model.unet.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + # logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.model.unet.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.model.unet.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.model.unet.time_embedding(t_emb, timestep_cond) + + if self.model.unet.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.model.unet.config.class_embed_type == "timestep": + class_labels = self.model.unet.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.model.unet.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.model.unet.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.model.unet.config.addition_embed_type == "text": + aug_emb = self.model.unet.add_embedding(encoder_hidden_states) + emb = emb + aug_emb + elif self.model.unet.config.addition_embed_type == "text_image": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.model.unet.__class__} has the config param `addition_embed_type` set to 'text_image' " + f"which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + + aug_emb = self.model.unet.add_embedding(text_embs, image_embs) + emb = emb + aug_emb + + if self.model.unet.time_embed_act is not None: + emb = self.model.unet.time_embed_act(emb) + + if self.model.unet.encoder_hid_proj is not None and self.model.unet.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.model.unet.encoder_hid_proj(encoder_hidden_states) + elif self.model.unet.encoder_hid_proj is not None and \ + self.model.unet.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.model.unet.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' " + f"which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.model.unet.encoder_hid_proj(encoder_hidden_states, image_embeds) + + # 2. pre-process + sample = self.model.unet.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.model.unet.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.model.unet.mid_block is not None: + sample = self.model.unet.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + # print(sample.shape) + + if replace_h_space is None: + h_space = sample.clone() + else: + h_space = replace_h_space + sample = replace_h_space.clone() + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + extracted_res_conns = {} + # 5. up + for i, upsample_block in enumerate(self.model.unet.up_blocks): + is_final_block = i == len(self.model.unet.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets):] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + if replace_skip_conns is not None and replace_skip_conns.get(i): + res_samples = replace_skip_conns.get(i) + + if zero_out_resconns is not None: + if (type(zero_out_resconns) is int and i >= (zero_out_resconns - 1)) or \ + type(zero_out_resconns) is list and i in zero_out_resconns: + res_samples = [torch.zeros_like(x) for x in res_samples] + # down_block_res_samples = [torch.zeros_like(x) for x in down_block_res_samples] + + extracted_res_conns[i] = res_samples + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + if self.model.unet.conv_norm_out: + sample = self.model.unet.conv_norm_out(sample) + sample = self.model.unet.conv_act(sample) + sample = self.model.unet.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample), h_space, extracted_res_conns + + +class AudioLDM2Wrapper(PipelineWrapper): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + if self.double_precision: + self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, torch_dtype=torch.float64).to(self.device) + else: + try: + self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=True).to(self.device) + except FileNotFoundError: + self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=False).to(self.device) + + def load_scheduler(self): + # self.model.scheduler = DDIMScheduler.from_config(self.model_id, subfolder="scheduler") + self.model.scheduler = DDIMScheduler.from_pretrained(self.model_id, subfolder="scheduler") + + def get_fn_STFT(self): + from audioldm.audio import TacotronSTFT + return TacotronSTFT( + filter_length=1024, + hop_length=160, + win_length=1024, + n_mel_channels=64, + sampling_rate=16000, + mel_fmin=0, + mel_fmax=8000, + ) + + def vae_encode(self, x): + # self.model.vae.disable_tiling() + if x.shape[2] % 4: + x = torch.nn.functional.pad(x, (0, 0, 4 - (x.shape[2] % 4), 0)) + return (self.model.vae.encode(x).latent_dist.mode() * self.model.vae.config.scaling_factor).float() + # return (self.encode_no_tiling(x).latent_dist.mode() * self.model.vae.config.scaling_factor).float() + + def vae_decode(self, x): + return self.model.vae.decode(1 / self.model.vae.config.scaling_factor * x).sample + + def decode_to_mel(self, x): + if self.double_precision: + tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().double()).detach() + tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().float()).detach() + if len(tmp.shape) == 1: + tmp = tmp.unsqueeze(0) + return tmp + + def encode_text(self, prompts: List[str]): + tokenizers = [self.model.tokenizer, self.model.tokenizer_2] + text_encoders = [self.model.text_encoder, self.model.text_encoder_2] + prompt_embeds_list = [] + attention_mask_list = [] + + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + prompts, + padding="max_length" if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True, + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] \ + and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, tokenizer.model_max_length - 1: -1]) + print(f"The following part of your input was truncated because {text_encoder.config.model_type} can " + f"only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(self.device) + attention_mask = attention_mask.to(self.device) + + with torch.no_grad(): + if text_encoder.config.model_type == "clap": + prompt_embeds = text_encoder.get_text_features( + text_input_ids, + attention_mask=attention_mask, + ) + # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size) + prompt_embeds = prompt_embeds[:, None, :] + # make sure that we attend to this single hidden-state + attention_mask = attention_mask.new_ones((len(prompts), 1)) + else: + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds_list.append(prompt_embeds) + attention_mask_list.append(attention_mask) + + # print(f'prompt[0].shape: {prompt_embeds_list[0].shape}') + # print(f'prompt[1].shape: {prompt_embeds_list[1].shape}') + # print(f'attn[0].shape: {attention_mask_list[0].shape}') + # print(f'attn[1].shape: {attention_mask_list[1].shape}') + + projection_output = self.model.projection_model( + hidden_states=prompt_embeds_list[0], + hidden_states_1=prompt_embeds_list[1], + attention_mask=attention_mask_list[0], + attention_mask_1=attention_mask_list[1], + ) + projected_prompt_embeds = projection_output.hidden_states + projected_attention_mask = projection_output.attention_mask + + generated_prompt_embeds = self.model.generate_language_model( + projected_prompt_embeds, + attention_mask=projected_attention_mask, + max_new_tokens=None, + ) + + prompt_embeds = prompt_embeds.to(dtype=self.model.text_encoder_2.dtype, device=self.device) + attention_mask = ( + attention_mask.to(device=self.device) + if attention_mask is not None + else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=self.device) + ) + generated_prompt_embeds = generated_prompt_embeds.to(dtype=self.model.language_model.dtype, device=self.device) + + return generated_prompt_embeds, prompt_embeds, attention_mask + + def get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + return variance + + def get_alpha_prod_t_prev(self, prev_timestep): + return self.model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 \ + else self.model.scheduler.final_alpha_cumprod + + def unet_forward(self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + timestep_cond: Optional[torch.Tensor] = None, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + replace_h_space: Optional[torch.Tensor] = None, + replace_skip_conns: Optional[Dict[int, torch.Tensor]] = None, + zero_out_resconns: Optional[Union[int, List]] = None) -> Tuple: + + # Translation + encoder_hidden_states_1 = class_labels + class_labels = None + encoder_attention_mask_1 = encoder_attention_mask + encoder_attention_mask = None + + # return self.model.unet(sample, timestep, + # encoder_hidden_states=generated_prompt_embeds, + # encoder_hidden_states_1=encoder_hidden_states_1, + # encoder_attention_mask_1=encoder_attention_mask_1, + # ), None, None + + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2 ** self.model.unet.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + # print("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + if encoder_attention_mask_1 is not None: + encoder_attention_mask_1 = (1 - encoder_attention_mask_1.to(sample.dtype)) * -10000.0 + encoder_attention_mask_1 = encoder_attention_mask_1.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.model.unet.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.model.unet.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.model.unet.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.model.unet.config.class_embed_type == "timestep": + class_labels = self.model.unet.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.model.unet.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.model.unet.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.model.unet.time_embed_act is not None: + emb = self.model.unet.time_embed_act(emb) + + # 2. pre-process + sample = self.model.unet.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.model.unet.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + encoder_hidden_states_1=encoder_hidden_states_1, + encoder_attention_mask_1=encoder_attention_mask_1, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. mid + if self.model.unet.mid_block is not None: + sample = self.model.unet.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + encoder_hidden_states_1=encoder_hidden_states_1, + encoder_attention_mask_1=encoder_attention_mask_1, + ) + + if replace_h_space is None: + h_space = sample.clone() + else: + h_space = replace_h_space + sample = replace_h_space.clone() + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + extracted_res_conns = {} + # 5. up + for i, upsample_block in enumerate(self.model.unet.up_blocks): + is_final_block = i == len(self.model.unet.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets):] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + if replace_skip_conns is not None and replace_skip_conns.get(i): + res_samples = replace_skip_conns.get(i) + + if zero_out_resconns is not None: + if (type(zero_out_resconns) is int and i >= (zero_out_resconns - 1)) or \ + type(zero_out_resconns) is list and i in zero_out_resconns: + res_samples = [torch.zeros_like(x) for x in res_samples] + # down_block_res_samples = [torch.zeros_like(x) for x in down_block_res_samples] + + extracted_res_conns[i] = res_samples + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + encoder_hidden_states_1=encoder_hidden_states_1, + encoder_attention_mask_1=encoder_attention_mask_1, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 6. post-process + if self.model.unet.conv_norm_out: + sample = self.model.unet.conv_norm_out(sample) + sample = self.model.unet.conv_act(sample) + sample = self.model.unet.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample), h_space, extracted_res_conns + + def forward(self, *args, **kwargs): + return self + + +def load_model(model_id, device, num_diffusion_steps, double_precision=False): + ldm_stable = AudioLDM2Wrapper(model_id=model_id, device=device, double_precision=double_precision) + ldm_stable.load_scheduler() + ldm_stable.model.scheduler.set_timesteps(num_diffusion_steps, device=device) + torch.cuda.empty_cache() + # controller = AttentionStore() + # controller = EmptyControl() + # register_attention_control(ldm_stable.model, controller) + # return ldm_stable, controller + return ldm_stable diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..7fb3722dcd2f7a0b55a4f6ac5f5b7b2e7d824254 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +torch +torchaudio +diffusers +accelerate +transformers +tqdm +soundfile +progressbar +einops +scipy +librosa==0.9.2 diff --git a/style.css b/style.css new file mode 100644 index 0000000000000000000000000000000000000000..1d6eedacc8cb83b31c9c0072b37a886e32308f63 --- /dev/null +++ b/style.css @@ -0,0 +1,4 @@ +.gradio-container { + max-width: 1050px !important; + padding-top: 1.5rem !important; +} diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..52a337c2250026cca334f26ae9357d912495e642 --- /dev/null +++ b/utils.py @@ -0,0 +1,71 @@ +import numpy as np +import torch +from typing import Optional, List, Tuple, NamedTuple, Union +from models import PipelineWrapper + + +class PromptEmbeddings(NamedTuple): + embedding_hidden_states: torch.Tensor + embedding_class_lables: torch.Tensor + boolean_prompt_mask: torch.Tensor + + +def load_audio(audio_path: Union[str, np.array], fn_STFT, left: int = 0, right: int = 0, device: Optional[torch.device] = None + ) -> torch.tensor: + if type(audio_path) is str: + import audioldm + import audioldm.audio + + duration = audioldm.utils.get_duration(audio_path) + + mel, _, _ = audioldm.audio.wav_to_fbank(audio_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT) + mel = mel.unsqueeze(0) + else: + mel = audio_path + + c, h, w = mel.shape + left = min(left, w-1) + right = min(right, w - left - 1) + mel = mel[:, :, left:w-right] + mel = mel.unsqueeze(0).to(device) + + return mel + + +def get_height_of_spectrogram(length: int, ldm_stable: PipelineWrapper) -> int: + vocoder_upsample_factor = np.prod(ldm_stable.model.vocoder.config.upsample_rates) / \ + ldm_stable.model.vocoder.config.sampling_rate + + if length is None: + length = ldm_stable.model.unet.config.sample_size * ldm_stable.model.vae_scale_factor * \ + vocoder_upsample_factor + + height = int(length / vocoder_upsample_factor) + + # original_waveform_length = int(length * ldm_stable.model.vocoder.config.sampling_rate) + if height % ldm_stable.model.vae_scale_factor != 0: + height = int(np.ceil(height / ldm_stable.model.vae_scale_factor)) * ldm_stable.model.vae_scale_factor + print( + f"Audio length in seconds {length} is increased to {height * vocoder_upsample_factor} " + f"so that it can be handled by the model. It will be cut to {length} after the " + f"denoising process." + ) + + return height + + +def get_text_embeddings(target_prompt: List[str], target_neg_prompt: List[str], ldm_stable: PipelineWrapper + ) -> Tuple[torch.Tensor, PromptEmbeddings, PromptEmbeddings]: + text_embeddings_hidden_states, text_embeddings_class_labels, text_embeddings_boolean_prompt_mask = \ + ldm_stable.encode_text(target_prompt) + uncond_embedding_hidden_states, uncond_embedding_class_lables, uncond_boolean_prompt_mask = \ + ldm_stable.encode_text(target_neg_prompt) + + text_emb = PromptEmbeddings(embedding_hidden_states=text_embeddings_hidden_states, + boolean_prompt_mask=text_embeddings_boolean_prompt_mask, + embedding_class_lables=text_embeddings_class_labels) + uncond_emb = PromptEmbeddings(embedding_hidden_states=uncond_embedding_hidden_states, + boolean_prompt_mask=uncond_boolean_prompt_mask, + embedding_class_lables=uncond_embedding_class_lables) + + return text_embeddings_class_labels, text_emb, uncond_emb