diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..83cfd8dbb643612f79f25d84b65ac7e4b3c4fb7f 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..2f78cf5b66514f2506d9af5f3dadf3dee7aa6d9f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+*.pyc
+
diff --git a/Examples/Beethoven.wav b/Examples/Beethoven.wav
new file mode 100755
index 0000000000000000000000000000000000000000..5b2d61c0fbebcbb1ef2e040cc975cf45af248337
--- /dev/null
+++ b/Examples/Beethoven.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:30a6a087a9e0eb87422aa3b48ad966eabb1dfe105d73a25d356b71d3aee31493
+size 4828972
diff --git a/Examples/Beethoven_arcade.wav b/Examples/Beethoven_arcade.wav
new file mode 100644
index 0000000000000000000000000000000000000000..8a6bf3f4e642681ce95e7e8c604862d22f61046a
--- /dev/null
+++ b/Examples/Beethoven_arcade.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ccd929b93c15706f2102a27973d490a84ce0eb97faba6a92ece0c6d81ed2c26e
+size 1794746
diff --git a/Examples/Beethoven_piano.wav b/Examples/Beethoven_piano.wav
new file mode 100644
index 0000000000000000000000000000000000000000..0dedc751b223a8209344c92129a7ca45760f08f1
--- /dev/null
+++ b/Examples/Beethoven_piano.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5787c31b0b3c78dec33d651d437364785713042e7cfce2290cf4baf01f65ac6f
+size 1794746
diff --git a/Examples/Cat.wav b/Examples/Cat.wav
new file mode 100644
index 0000000000000000000000000000000000000000..bb849ac6f86dadcb51af1992415b4428901e7b77
--- /dev/null
+++ b/Examples/Cat.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:27b43763a8d9ac90dc78285ed9817b16524f24b4f4d1aa399616f1a04d4a9fd9
+size 1920508
diff --git a/Examples/Cat_dog.wav b/Examples/Cat_dog.wav
new file mode 100644
index 0000000000000000000000000000000000000000..3e400307510e72d7edf889c6fc17737351c64387
--- /dev/null
+++ b/Examples/Cat_dog.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:90a97dc229eeccef307dd40db97fb09cc439ce0b45a320fd84b2ea6b03d0deb2
+size 327822
diff --git a/Examples/ModalJazz.wav b/Examples/ModalJazz.wav
new file mode 100644
index 0000000000000000000000000000000000000000..a29553d312f8bd8071512c8c85ff5ead79047c61
--- /dev/null
+++ b/Examples/ModalJazz.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:846a77046d21ebc3996841404eede9d56797c82b3414025e1ccafe586eaf2959
+size 9153322
diff --git a/Examples/ModalJazz_banjo.wav b/Examples/ModalJazz_banjo.wav
new file mode 100644
index 0000000000000000000000000000000000000000..99a51a2b14e03075aecf18a7ea5f9f0acafde473
--- /dev/null
+++ b/Examples/ModalJazz_banjo.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:122e0078c0bf2fc96425071706fe0e8674c93cc1d2787fd02c0e2c0f12de5cc5
+size 6802106
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..1532c5b27dbedfcaf7eeaf9e8f76a340bf61e43a
--- /dev/null
+++ b/app.py
@@ -0,0 +1,271 @@
+import gradio as gr
+import random
+import torch
+import torchaudio
+from torch import inference_mode
+from tempfile import NamedTemporaryFile
+import numpy as np
+from models import load_model
+import utils
+from inversion_utils import inversion_forward_process, inversion_reverse_process
+
+
+def randomize_seed_fn(seed, randomize_seed):
+ if randomize_seed:
+ seed = random.randint(0, np.iinfo(np.int32).max)
+ torch.manual_seed(seed)
+ return seed
+
+
+def invert(x0, prompt_src, num_diffusion_steps, cfg_scale_src): # , ldm_stable):
+ ldm_stable.model.scheduler.set_timesteps(num_diffusion_steps, device=device)
+
+ with inference_mode():
+ w0 = ldm_stable.vae_encode(x0)
+
+ # find Zs and wts - forward process
+ _, zs, wts = inversion_forward_process(ldm_stable, w0, etas=1,
+ prompts=[prompt_src],
+ cfg_scales=[cfg_scale_src],
+ prog_bar=True,
+ num_inference_steps=num_diffusion_steps,
+ numerical_fix=True)
+ return zs, wts
+
+
+def sample(zs, wts, steps, prompt_tar, tstart, cfg_scale_tar): # , ldm_stable):
+ # reverse process (via Zs and wT)
+ tstart = torch.tensor(tstart, dtype=torch.int)
+ skip = steps - tstart
+ w0, _ = inversion_reverse_process(ldm_stable, xT=wts, skips=steps - skip,
+ etas=1., prompts=[prompt_tar],
+ neg_prompts=[""], cfg_scales=[cfg_scale_tar],
+ prog_bar=True,
+ zs=zs[:int(steps - skip)])
+
+ # vae decode image
+ with inference_mode():
+ x0_dec = ldm_stable.vae_decode(w0)
+ if x0_dec.dim() < 4:
+ x0_dec = x0_dec[None, :, :, :]
+
+ with torch.no_grad():
+ audio = ldm_stable.decode_to_mel(x0_dec)
+
+ f = NamedTemporaryFile("wb", suffix=".wav", delete=False)
+ torchaudio.save(f.name, audio, sample_rate=16000)
+
+ return f.name
+
+
+def edit(input_audio,
+ model_id: str,
+ do_inversion: bool,
+ wts: gr.State, zs: gr.State, saved_inv_model: str,
+ source_prompt="",
+ target_prompt="",
+ steps=200,
+ cfg_scale_src=3.5,
+ cfg_scale_tar=12,
+ t_start=90,
+ randomize_seed=True):
+
+ global ldm_stable, current_loaded_model
+ print(f'current loaded model: {ldm_stable.model_id}')
+ if model_id != current_loaded_model:
+ print(f'Changing model to {model_id}...')
+ current_loaded_model = model_id
+ ldm_stable = None
+ ldm_stable = load_model(model_id, device, steps)
+
+ # If the inversion was done for a different model, we need to re-run the inversion
+ if not do_inversion and (saved_inv_model is None or saved_inv_model != model_id):
+ do_inversion = True
+
+ x0 = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device)
+
+ if do_inversion or randomize_seed: # re-run the inversion when needed, or always when the seed is randomized
+ zs_tensor, wts_tensor = invert(x0=x0, prompt_src=source_prompt,
+ num_diffusion_steps=steps,
+ cfg_scale_src=cfg_scale_src)
+ wts = gr.State(value=wts_tensor)
+ zs = gr.State(value=zs_tensor)
+ saved_inv_model = model_id
+ do_inversion = False
+
+ output = sample(zs.value, wts.value, steps, prompt_tar=target_prompt, tstart=t_start,
+ cfg_scale_tar=cfg_scale_tar)
+
+ return output, wts, zs, saved_inv_model, do_inversion
+
+
+current_loaded_model = "cvssp/audioldm2-music"
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ldm_stable = load_model(current_loaded_model, device, 200) # default model
+
+
+def get_example():
+ case = [
+ ['Examples/Beethoven.wav',
+ '',
+ 'A recording of an arcade game soundtrack.',
+ 90,
+ 'cvssp/audioldm2-music',
+ '27s',
+ 'Examples/Beethoven_arcade.wav',
+ ],
+ ['Examples/Beethoven.wav',
+ 'A high quality recording of wind instruments and strings playing.',
+ 'A high quality recording of a piano playing.',
+ 90,
+ 'cvssp/audioldm2-music',
+ '27s',
+ 'Examples/Beethoven_piano.wav',
+ ],
+ ['Examples/ModalJazz.wav',
+ 'Trumpets playing alongside a piano, bass and drums in an upbeat old-timey cool jazz song.',
+ 'A banjo playing alongside a piano, bass and drums in an upbeat old-timey cool country song.',
+ 90,
+ 'cvssp/audioldm2-music',
+ '106s',
+ 'Examples/ModalJazz_banjo.wav',],
+ ['Examples/Cat.wav',
+ '',
+ 'A dog barking.',
+ 150,
+ 'cvssp/audioldm2-large',
+ '10s',
+ 'Examples/Cat_dog.wav',]
+ ]
+ return case
+
+
+intro = """
+
Zero-Shot Text-Based Audio Editing Using DDPM Inversion
+
+
+Demo for the text-based editing method introduced in:
+ Zero-Shot Unsupervised and Text-Based Audio Editing Using DDPM Inversion
+
+
+Instructions:
+Provide an input audio and a target prompt to edit the audio.
+Tstart is used to control the tradeoff between fidelity to the original signal and text-adhearance.
+Lower value -> favor fidelity. Higher value -> apply a stronger edit.
+Make sure that you use an AudioLDM2 version that is suitable for your input audio.
+For example, use the music version for music and the large version for general audio.
+
+
+You can additionally provide a source prompt to guide even further the editing process.
+
+Longer input will take more time.
+
+For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
+
+
+
+
+"""
+
+with gr.Blocks(css='style.css') as demo:
+ def reset_do_inversion():
+ do_inversion = gr.State(value=True)
+ return do_inversion
+
+ gr.HTML(intro)
+ wts = gr.State()
+ zs = gr.State()
+ saved_inv_model = gr.State()
+ # current_loaded_model = gr.State(value="cvssp/audioldm2-music")
+ # ldm_stable = load_model("cvssp/audioldm2-music", device, 200)
+ # ldm_stable = gr.State(value=ldm_stable)
+ do_inversion = gr.State(value=True) # To save some runtime when editing the same thing over and over
+
+ with gr.Row():
+ with gr.Column():
+ src_prompt = gr.Textbox(label="OPTIONAL: Source Prompt", lines=2, interactive=True,
+ placeholder="Optional: Describe the original audio input",)
+ input_audio = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Input Audio",
+ interactive=True, scale=1)
+
+ with gr.Column():
+ tar_prompt = gr.Textbox(label="Target Prompt", placeholder="Describe your desired edited output",
+ lines=2, interactive=True)
+ output_audio = gr.Audio(label="Edited Audio", interactive=False, scale=1)
+
+ with gr.Row():
+ with gr.Column():
+ submit = gr.Button("Edit")
+
+ with gr.Row():
+ t_start = gr.Slider(minimum=30, maximum=160, value=110, step=1, label="T-start", interactive=True, scale=3,
+ info="Higher T-start -> stronger edit. Lower T-start -> more similar to original audio.")
+ model_id = gr.Dropdown(label="AudioLDM2 Version", choices=["cvssp/audioldm2",
+ "cvssp/audioldm2-large",
+ "cvssp/audioldm2-music"],
+ info="Choose a checkpoint suitable for your intended audio and edit.",
+ value="cvssp/audioldm2-music", interactive=True, type="value", scale=2)
+ with gr.Accordion("More Options", open=False):
+
+ with gr.Row():
+ cfg_scale_src = gr.Number(value=3, minimum=0.5, maximum=25, precision=None,
+ label="Source Guidance Scale", interactive=True, scale=1)
+ cfg_scale_tar = gr.Number(value=12, minimum=0.5, maximum=25, precision=None,
+ label="Target Guidance Scale", interactive=True, scale=1)
+ steps = gr.Number(value=200, precision=0, minimum=20, maximum=1000,
+ label="Num Diffusion Steps", interactive=True, scale=1)
+ with gr.Row():
+ seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
+ randomize_seed = gr.Checkbox(label='Randomize seed', value=False)
+ length = gr.Number(label="Length", interactive=False, visible=False)
+
+ def change_tstart_range(steps):
+ t_start.maximum = int(160/200 * steps)
+ t_start.minimum = int(30/200 * steps)
+ if t_start.value > t_start.maximum:
+ t_start.value = t_start.maximum
+ if t_start.value < t_start.minimum:
+ t_start.value = t_start.minimum
+ return t_start
+
+ submit.click(
+ fn=randomize_seed_fn,
+ inputs=[seed, randomize_seed],
+ outputs=[seed], queue=False).then(
+ fn=edit,
+ inputs=[input_audio,
+ model_id,
+ do_inversion,
+ # current_loaded_model, ldm_stable,
+ wts, zs, saved_inv_model,
+ src_prompt,
+ tar_prompt,
+ steps,
+ cfg_scale_src,
+ cfg_scale_tar,
+ t_start,
+ randomize_seed
+ ],
+ outputs=[output_audio, wts, zs, saved_inv_model, do_inversion] # , current_loaded_model, ldm_stable],
+ )
+
+ # If the source audio, source prompt, or model changed, we have to rerun the inversion
+ input_audio.change(fn=reset_do_inversion, outputs=[do_inversion])
+ src_prompt.change(fn=reset_do_inversion, outputs=[do_inversion])
+ model_id.change(fn=reset_do_inversion, outputs=[do_inversion])
+ steps.change(fn=change_tstart_range, inputs=[steps], outputs=[t_start])
+
+ gr.Examples(
+ label="Examples",
+ examples=get_example(),
+ inputs=[input_audio, src_prompt, tar_prompt, t_start, model_id, length, output_audio],
+ outputs=[output_audio]
+ )
+
+ demo.queue()
+ demo.launch()
diff --git a/audioldm/__init__.py b/audioldm/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..2bbf85f01ccc72b6f18e7405d940adf07a26b500
--- /dev/null
+++ b/audioldm/__init__.py
@@ -0,0 +1,8 @@
+from .ldm import LatentDiffusion
+from .utils import seed_everything, save_wave, get_time, get_duration
+from .pipeline import *
+
+
+
+
+
diff --git a/audioldm/__main__.py b/audioldm/__main__.py
new file mode 100755
index 0000000000000000000000000000000000000000..fd597a810642bf518f7bbbc113e134526d918ba7
--- /dev/null
+++ b/audioldm/__main__.py
@@ -0,0 +1,183 @@
+#!/usr/bin/python3
+import os
+from audioldm import text_to_audio, style_transfer, build_model, save_wave, get_time, round_up_duration, get_duration
+import argparse
+
+CACHE_DIR = os.getenv(
+ "AUDIOLDM_CACHE_DIR",
+ os.path.join(os.path.expanduser("~"), ".cache/audioldm"))
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument(
+ "--mode",
+ type=str,
+ required=False,
+ default="generation",
+ help="generation: text-to-audio generation; transfer: style transfer",
+ choices=["generation", "transfer"]
+)
+
+parser.add_argument(
+ "-t",
+ "--text",
+ type=str,
+ required=False,
+ default="",
+ help="Text prompt to the model for audio generation",
+)
+
+parser.add_argument(
+ "-f",
+ "--file_path",
+ type=str,
+ required=False,
+ default=None,
+ help="(--mode transfer): Original audio file for style transfer; Or (--mode generation): the guidance audio file for generating simialr audio",
+)
+
+parser.add_argument(
+ "--transfer_strength",
+ type=float,
+ required=False,
+ default=0.5,
+ help="A value between 0 and 1. 0 means original audio without transfer, 1 means completely transfer to the audio indicated by text",
+)
+
+parser.add_argument(
+ "-s",
+ "--save_path",
+ type=str,
+ required=False,
+ help="The path to save model output",
+ default="./output",
+)
+
+parser.add_argument(
+ "--model_name",
+ type=str,
+ required=False,
+ help="The checkpoint you gonna use",
+ default="audioldm-m-full",
+ choices=["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2","audioldm-m-text-ft", "audioldm-s-text-ft", "audioldm-m-full"]
+)
+
+parser.add_argument(
+ "-ckpt",
+ "--ckpt_path",
+ type=str,
+ required=False,
+ help="The path to the pretrained .ckpt model",
+ default=None,
+)
+
+parser.add_argument(
+ "-b",
+ "--batchsize",
+ type=int,
+ required=False,
+ default=1,
+ help="Generate how many samples at the same time",
+)
+
+parser.add_argument(
+ "--ddim_steps",
+ type=int,
+ required=False,
+ default=200,
+ help="The sampling step for DDIM",
+)
+
+parser.add_argument(
+ "-gs",
+ "--guidance_scale",
+ type=float,
+ required=False,
+ default=2.5,
+ help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)",
+)
+
+parser.add_argument(
+ "-dur",
+ "--duration",
+ type=float,
+ required=False,
+ default=10.0,
+ help="The duration of the samples",
+)
+
+parser.add_argument(
+ "-n",
+ "--n_candidate_gen_per_text",
+ type=int,
+ required=False,
+ default=3,
+ help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). A Larger value usually lead to better quality with heavier computation",
+)
+
+parser.add_argument(
+ "--seed",
+ type=int,
+ required=False,
+ default=42,
+ help="Change this value (any integer number) will lead to a different generation result.",
+)
+
+args = parser.parse_args()
+
+if(args.ckpt_path is not None):
+ print("Warning: ckpt_path has no effect after version 0.0.20.")
+
+assert args.duration % 2.5 == 0, "Duration must be a multiple of 2.5"
+
+mode = args.mode
+if(mode == "generation" and args.file_path is not None):
+ mode = "generation_audio_to_audio"
+ if(len(args.text) > 0):
+ print("Warning: You have specified the --file_path. --text will be ignored")
+ args.text = ""
+
+save_path = os.path.join(args.save_path, mode)
+
+if(args.file_path is not None):
+ save_path = os.path.join(save_path, os.path.basename(args.file_path.split(".")[0]))
+
+text = args.text
+random_seed = args.seed
+duration = args.duration
+guidance_scale = args.guidance_scale
+n_candidate_gen_per_text = args.n_candidate_gen_per_text
+
+os.makedirs(save_path, exist_ok=True)
+audioldm = build_model(model_name=args.model_name)
+
+if(args.mode == "generation"):
+ waveform = text_to_audio(
+ audioldm,
+ text,
+ args.file_path,
+ random_seed,
+ duration=duration,
+ guidance_scale=guidance_scale,
+ ddim_steps=args.ddim_steps,
+ n_candidate_gen_per_text=n_candidate_gen_per_text,
+ batchsize=args.batchsize,
+ )
+
+elif(args.mode == "transfer"):
+ assert args.file_path is not None
+ assert os.path.exists(args.file_path), "The original audio file \'%s\' for style transfer does not exist." % args.file_path
+ waveform = style_transfer(
+ audioldm,
+ text,
+ args.file_path,
+ args.transfer_strength,
+ random_seed,
+ duration=duration,
+ guidance_scale=guidance_scale,
+ ddim_steps=args.ddim_steps,
+ batchsize=args.batchsize,
+ )
+ waveform = waveform[:,None,:]
+
+save_wave(waveform, save_path, name="%s_%s" % (get_time(), text))
diff --git a/audioldm/audio/__init__.py b/audioldm/audio/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..56902e96f041bc4ba6bfadd7a7742023b9560233
--- /dev/null
+++ b/audioldm/audio/__init__.py
@@ -0,0 +1,2 @@
+from .tools import wav_to_fbank, read_wav_file
+from .stft import TacotronSTFT
diff --git a/audioldm/audio/audio_processing.py b/audioldm/audio/audio_processing.py
new file mode 100755
index 0000000000000000000000000000000000000000..77a4057aa82f226f68474f4c2a19eba84510d663
--- /dev/null
+++ b/audioldm/audio/audio_processing.py
@@ -0,0 +1,100 @@
+import torch
+import numpy as np
+import librosa.util as librosa_util
+from scipy.signal import get_window
+
+
+def window_sumsquare(
+ window,
+ n_frames,
+ hop_length,
+ win_length,
+ n_fft,
+ dtype=np.float32,
+ norm=None,
+):
+ """
+ # from librosa 0.6
+ Compute the sum-square envelope of a window function at a given hop length.
+
+ This is used to estimate modulation effects induced by windowing
+ observations in short-time Fourier transforms.
+
+ Parameters
+ ----------
+ window : string, tuple, number, callable, or list-like
+ Window specification, as in `get_window`
+
+ n_frames : int > 0
+ The number of analysis frames
+
+ hop_length : int > 0
+ The number of samples to advance between frames
+
+ win_length : [optional]
+ The length of the window function. By default, this matches `n_fft`.
+
+ n_fft : int > 0
+ The length of each analysis frame.
+
+ dtype : np.dtype
+ The data type of the output
+
+ Returns
+ -------
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
+ The sum-squared envelope of the window function
+ """
+ if win_length is None:
+ win_length = n_fft
+
+ n = n_fft + hop_length * (n_frames - 1)
+ x = np.zeros(n, dtype=dtype)
+
+ # Compute the squared window at the desired length
+ win_sq = get_window(window, win_length, fftbins=True)
+ win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
+ win_sq = librosa_util.pad_center(win_sq, n_fft)
+
+ # Fill the envelope
+ for i in range(n_frames):
+ sample = i * hop_length
+ x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
+ return x
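+
+# Illustrative example (comment only): for a Hann window with n_fft=1024,
+# hop_length=160, win_length=1024 and n_frames=100, window_sumsquare returns an
+# envelope of shape (1024 + 160 * 99,) == (16864,), matching the formula above.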
+
+
+def griffin_lim(magnitudes, stft_fn, n_iters=30):
+ """
+ PARAMS
+ ------
+ magnitudes: spectrogram magnitudes
+ stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
+ """
+
+ angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
+ angles = angles.astype(np.float32)
+ angles = torch.autograd.Variable(torch.from_numpy(angles))
+ signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+
+ for i in range(n_iters):
+ _, angles = stft_fn.transform(signal)
+ signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+ return signal
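+
+# Illustrative usage (a sketch; assumes the STFT class from audioldm.audio.stft):
+#   stft = STFT(filter_length=1024, hop_length=160, win_length=1024)
+#   magnitude, _ = stft.transform(waveform)           # waveform: (B, T) in [-1, 1]
+#   recon = griffin_lim(magnitude, stft, n_iters=30)  # phase-free reconstruction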
+
+
+def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
+ """
+ PARAMS
+ ------
+ C: compression factor
+ """
+ return normalize_fun(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+ """
+ PARAMS
+ ------
+ C: compression factor used to compress
+ """
+ return torch.exp(x) / C
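+
+
+if __name__ == "__main__":
+    # Minimal round-trip sketch (illustration only, not used by the package):
+    # compression followed by decompression recovers the clipped input, since
+    # exp(log(clamp(x, 1e-5) * C)) / C == clamp(x, 1e-5) for C == 1.
+    x = torch.rand(4, 64, 100)
+    x_rec = dynamic_range_decompression(dynamic_range_compression(x))
+    assert torch.allclose(x_rec, torch.clamp(x, min=1e-5))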
diff --git a/audioldm/audio/stft.py b/audioldm/audio/stft.py
new file mode 100755
index 0000000000000000000000000000000000000000..6af485d45d5d371e7c9bc72531497fa7f7a716c1
--- /dev/null
+++ b/audioldm/audio/stft.py
@@ -0,0 +1,180 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy.signal import get_window
+from librosa.util import pad_center, tiny
+from librosa.filters import mel as librosa_mel_fn
+
+from audioldm.audio.audio_processing import (
+ dynamic_range_compression,
+ dynamic_range_decompression,
+ window_sumsquare,
+)
+
+
+class STFT(torch.nn.Module):
+ """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
+
+ def __init__(self, filter_length, hop_length, win_length, window="hann"):
+ super(STFT, self).__init__()
+ self.filter_length = filter_length
+ self.hop_length = hop_length
+ self.win_length = win_length
+ self.window = window
+ self.forward_transform = None
+ scale = self.filter_length / self.hop_length
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+ cutoff = int((self.filter_length / 2 + 1))
+ fourier_basis = np.vstack(
+ [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
+ )
+
+ forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+ inverse_basis = torch.FloatTensor(
+ np.linalg.pinv(scale * fourier_basis).T[:, None, :]
+ )
+
+ if window is not None:
+ assert filter_length >= win_length
+ # get window and zero center pad it to filter_length
+ fft_window = get_window(window, win_length, fftbins=True)
+ fft_window = pad_center(fft_window, size=filter_length)
+ fft_window = torch.from_numpy(fft_window).float()
+
+ # window the bases
+ forward_basis *= fft_window
+ inverse_basis *= fft_window
+
+ self.register_buffer("forward_basis", forward_basis.float())
+ self.register_buffer("inverse_basis", inverse_basis.float())
+
+ def transform(self, input_data):
+ num_batches = input_data.size(0)
+ num_samples = input_data.size(1)
+
+ self.num_samples = num_samples
+
+ # similar to librosa, reflect-pad the input
+ input_data = input_data.view(num_batches, 1, num_samples)
+ input_data = F.pad(
+ input_data.unsqueeze(1),
+ (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
+ mode="reflect",
+ )
+ input_data = input_data.squeeze(1)
+
+ forward_transform = F.conv1d(
+ input_data,
+ torch.autograd.Variable(self.forward_basis, requires_grad=False),
+ stride=self.hop_length,
+ padding=0,
+ ).cpu()
+
+ cutoff = int((self.filter_length / 2) + 1)
+ real_part = forward_transform[:, :cutoff, :]
+ imag_part = forward_transform[:, cutoff:, :]
+
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
+ phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
+
+ return magnitude, phase
+
+ def inverse(self, magnitude, phase):
+ recombine_magnitude_phase = torch.cat(
+ [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
+ )
+
+ inverse_transform = F.conv_transpose1d(
+ recombine_magnitude_phase,
+ torch.autograd.Variable(self.inverse_basis, requires_grad=False),
+ stride=self.hop_length,
+ padding=0,
+ )
+
+ if self.window is not None:
+ window_sum = window_sumsquare(
+ self.window,
+ magnitude.size(-1),
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ n_fft=self.filter_length,
+ dtype=np.float32,
+ )
+ # remove modulation effects
+ approx_nonzero_indices = torch.from_numpy(
+ np.where(window_sum > tiny(window_sum))[0]
+ )
+ window_sum = torch.autograd.Variable(
+ torch.from_numpy(window_sum), requires_grad=False
+ )
+ window_sum = window_sum
+ inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
+ approx_nonzero_indices
+ ]
+
+ # scale by hop ratio
+ inverse_transform *= float(self.filter_length) / self.hop_length
+
+ inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
+ inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
+
+ return inverse_transform
+
+ def forward(self, input_data):
+ self.magnitude, self.phase = self.transform(input_data)
+ reconstruction = self.inverse(self.magnitude, self.phase)
+ return reconstruction
+
+
+class TacotronSTFT(torch.nn.Module):
+ def __init__(
+ self,
+ filter_length,
+ hop_length,
+ win_length,
+ n_mel_channels,
+ sampling_rate,
+ mel_fmin,
+ mel_fmax,
+ ):
+ super(TacotronSTFT, self).__init__()
+ self.n_mel_channels = n_mel_channels
+ self.sampling_rate = sampling_rate
+ self.stft_fn = STFT(filter_length, hop_length, win_length)
+ mel_basis = librosa_mel_fn(
+ sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax
+ )
+ mel_basis = torch.from_numpy(mel_basis).float()
+ self.register_buffer("mel_basis", mel_basis)
+
+ def spectral_normalize(self, magnitudes, normalize_fun):
+ output = dynamic_range_compression(magnitudes, normalize_fun)
+ return output
+
+ def spectral_de_normalize(self, magnitudes):
+ output = dynamic_range_decompression(magnitudes)
+ return output
+
+ def mel_spectrogram(self, y, normalize_fun=torch.log):
+ """Computes mel-spectrograms from a batch of waves
+ PARAMS
+ ------
+ y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
+
+ RETURNS
+ -------
+ mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
+ """
+ assert torch.min(y.data) >= -1, torch.min(y.data)
+ assert torch.max(y.data) <= 1, torch.max(y.data)
+
+ magnitudes, phases = self.stft_fn.transform(y)
+ magnitudes = magnitudes.data
+ mel_output = torch.matmul(self.mel_basis, magnitudes)
+ mel_output = self.spectral_normalize(mel_output, normalize_fun)
+ energy = torch.norm(magnitudes, dim=1)
+
+ log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
+
+ return mel_output, log_magnitudes, energy
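+
+
+if __name__ == "__main__":
+    # Illustrative sketch (not used elsewhere in the package). The filter-bank
+    # settings below are assumed example values; the real ones come from the
+    # model config.
+    stft = TacotronSTFT(filter_length=1024, hop_length=160, win_length=1024,
+                        n_mel_channels=64, sampling_rate=16000,
+                        mel_fmin=0, mel_fmax=8000)
+    wav = torch.rand(1, 16000) * 2 - 1  # one second of audio in [-1, 1]
+    mel, log_mag, energy = stft.mel_spectrogram(wav)
+    print(mel.shape)  # (1, 64, n_frames)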
diff --git a/audioldm/audio/tools.py b/audioldm/audio/tools.py
new file mode 100755
index 0000000000000000000000000000000000000000..d641a982664b6673822c8528a1929c593f011b11
--- /dev/null
+++ b/audioldm/audio/tools.py
@@ -0,0 +1,85 @@
+import torch
+import numpy as np
+import torchaudio
+
+
+def get_mel_from_wav(audio, _stft):
+ audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
+ audio = torch.autograd.Variable(audio, requires_grad=False)
+ melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
+ melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
+ log_magnitudes_stft = (
+ torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32)
+ )
+ energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
+ return melspec, log_magnitudes_stft, energy
+
+
+def _pad_spec(fbank, target_length=1024):
+ n_frames = fbank.shape[0]
+ p = target_length - n_frames
+ # cut and pad
+ if p > 0:
+ m = torch.nn.ZeroPad2d((0, 0, 0, p))
+ fbank = m(fbank)
+ elif p < 0:
+ fbank = fbank[0:target_length, :]
+
+ if fbank.size(-1) % 2 != 0:
+ fbank = fbank[..., :-1]
+
+ return fbank
+
+
+def pad_wav(waveform, segment_length):
+ waveform_length = waveform.shape[-1]
+ assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
+ if segment_length is None or waveform_length == segment_length:
+ return waveform
+ elif waveform_length > segment_length:
+ return waveform[:segment_length]
+ elif waveform_length < segment_length:
+ temp_wav = np.zeros((1, segment_length))
+ temp_wav[:, :waveform_length] = waveform
+ return temp_wav
+
+def normalize_wav(waveform):
+ waveform = waveform - np.mean(waveform)
+ waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
+ return waveform * 0.5
+
+
+def read_wav_file(filename, segment_length):
+ # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
+ waveform, sr = torchaudio.load(filename) # Faster!!!
+ waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
+ waveform = waveform.numpy()[0, ...]
+ waveform = normalize_wav(waveform)
+ waveform = waveform[None, ...]
+ waveform = pad_wav(waveform, segment_length)
+
+ waveform = waveform / np.max(np.abs(waveform))
+ waveform = 0.5 * waveform
+
+ return waveform
+
+
+def wav_to_fbank(filename, target_length=1024, fn_STFT=None):
+ assert fn_STFT is not None
+
+ # mixup
+ waveform = read_wav_file(filename, target_length * 160) # hop size is 160
+
+ waveform = waveform[0, ...]
+ waveform = torch.FloatTensor(waveform)
+
+ fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
+
+ fbank = torch.FloatTensor(fbank.T)
+ log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T)
+
+ fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
+ log_magnitudes_stft, target_length
+ )
+
+ return fbank, log_magnitudes_stft, waveform
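+
+
+if __name__ == "__main__":
+    # Illustrative sketch (not used by the package): "input.wav" is a placeholder
+    # path and the STFT settings are assumed example values, not read from a config.
+    from audioldm.audio.stft import TacotronSTFT
+
+    fn_STFT = TacotronSTFT(filter_length=1024, hop_length=160, win_length=1024,
+                           n_mel_channels=64, sampling_rate=16000,
+                           mel_fmin=0, mel_fmax=8000)
+    fbank, log_mag, wave = wav_to_fbank("input.wav", target_length=1024, fn_STFT=fn_STFT)
+    print(fbank.shape)  # (1024, 64): frames x mel bins, padded/cropped to 1024 frames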
diff --git a/audioldm/clap/__init__.py b/audioldm/clap/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm/clap/encoders.py b/audioldm/clap/encoders.py
new file mode 100755
index 0000000000000000000000000000000000000000..d82d57fc6ea7b6c74aa38db2af47abccbd4e9a00
--- /dev/null
+++ b/audioldm/clap/encoders.py
@@ -0,0 +1,171 @@
+import torch
+import torch.nn as nn
+from audioldm.clap.open_clip import create_model
+from audioldm.clap.training.data import get_audio_features
+import torchaudio
+from transformers import RobertaTokenizer
+import torch.nn.functional as F
+
+
+class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
+ def __init__(
+ self,
+ pretrained_path="",
+ key="class",
+ sampling_rate=16000,
+ embed_mode="audio",
+ amodel = "HTSAT-tiny",
+ unconditional_prob=0.1,
+ random_mute=False,
+ max_random_mute_portion=0.5,
+ training_mode=True,
+ ):
+ super().__init__()
+
+ self.key = key
+ self.device = "cpu"
+ self.precision = "fp32"
+ self.amodel = amodel # or 'PANN-14'
+ self.tmodel = "roberta" # the best text encoder in our training
+ self.enable_fusion = False # False if you do not want to use the fusion model
+ self.fusion_type = "aff_2d"
+ self.pretrained = pretrained_path
+ self.embed_mode = embed_mode
+ self.embed_mode_orig = embed_mode
+ self.sampling_rate = sampling_rate
+ self.unconditional_prob = unconditional_prob
+ self.random_mute = random_mute
+ self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")
+ self.max_random_mute_portion = max_random_mute_portion
+ self.training_mode = training_mode
+ self.model, self.model_cfg = create_model(
+ self.amodel,
+ self.tmodel,
+ self.pretrained,
+ precision=self.precision,
+ device=self.device,
+ enable_fusion=self.enable_fusion,
+ fusion_type=self.fusion_type,
+ )
+ for p in self.model.parameters():
+ p.requires_grad = False
+
+ self.model.eval()
+
+ def get_unconditional_condition(self, batchsize):
+ self.unconditional_token = self.model.get_text_embedding(
+ self.tokenizer(["", ""])
+ )[0:1]
+ return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)
+
+ def batch_to_list(self, batch):
+ ret = []
+ for i in range(batch.size(0)):
+ ret.append(batch[i])
+ return ret
+
+ def make_decision(self, probability):
+ if float(torch.rand(1)) < probability:
+ return True
+ else:
+ return False
+
+ def random_uniform(self, start, end):
+ val = torch.rand(1).item()
+ return start + (end - start) * val
+
+ def _random_mute(self, waveform):
+ # waveform: [bs, t-steps]
+ t_steps = waveform.size(-1)
+ for i in range(waveform.size(0)):
+ mute_size = int(
+ self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
+ )
+ mute_start = int(self.random_uniform(0, t_steps - mute_size))
+ waveform[i, mute_start : mute_start + mute_size] = 0
+ return waveform
+
+ def cos_similarity(self, waveform, text):
+ # waveform: [bs, t_steps]
+ with torch.no_grad():
+ self.embed_mode = "audio"
+ print(text)
+ audio_emb = self(waveform.cuda())
+ self.embed_mode = "text"
+ text_emb = self(text)
+ similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
+ return similarity.squeeze()
+
+ def forward(self, batch, key=None):
+ # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0
+ # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0
+ if self.model.training == True and not self.training_mode:
+ print(
+ "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters."
+ )
+ self.model, self.model_cfg = create_model(
+ self.amodel,
+ self.tmodel,
+ self.pretrained,
+ precision=self.precision,
+ device="cuda",
+ enable_fusion=self.enable_fusion,
+ fusion_type=self.fusion_type,
+ )
+ for p in self.model.parameters():
+ p.requires_grad = False
+ self.model.eval()
+
+ # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
+ if self.embed_mode == "audio":
+ with torch.no_grad():
+ audio_dict_list = []
+ assert (
+ self.sampling_rate == 16000
+ ), "We only support 16000 sampling rate"
+ if self.random_mute:
+ batch = self._random_mute(batch)
+ # batch: [bs, 1, t-samples]
+ batch = torchaudio.functional.resample(
+ batch, orig_freq=self.sampling_rate, new_freq=48000
+ )
+ for waveform in self.batch_to_list(batch):
+ audio_dict = {}
+ audio_dict = get_audio_features(
+ audio_dict,
+ waveform,
+ 480000,
+ data_truncating="fusion",
+ data_filling="repeatpad",
+ audio_cfg=self.model_cfg["audio_cfg"],
+ )
+ audio_dict_list.append(audio_dict)
+ # [bs, 512]
+ embed = self.model.get_audio_embedding(audio_dict_list)
+ elif self.embed_mode == "text":
+ with torch.no_grad():
+ # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
+ text_data = self.tokenizer(batch)
+ embed = self.model.get_text_embedding(text_data)
+
+ embed = embed.unsqueeze(1)
+ self.unconditional_token = self.model.get_text_embedding(
+ self.tokenizer(["", ""])
+ )[0:1]
+
+ for i in range(embed.size(0)):
+ if self.make_decision(self.unconditional_prob):
+ embed[i] = self.unconditional_token
+
+ # [bs, 1, 512]
+ return embed.detach()
+
+ def tokenizer(self, text):
+ result = self.tokenize(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=512,
+ return_tensors="pt",
+ )
+ return {k: v.squeeze(0) for k, v in result.items()}
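+
+# Illustrative note: with embed_mode="text" the module maps a list of strings to
+# (batch, 1, 512) CLAP text embeddings; with embed_mode="audio" it expects a
+# (batch, t_samples) waveform at 16 kHz, resamples it to 48 kHz internally, and
+# returns (batch, 1, 512) audio embeddings.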
diff --git a/audioldm/clap/open_clip/__init__.py b/audioldm/clap/open_clip/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e9f728f2f273be5d5fdbec6c6cc41d737176a8c0
--- /dev/null
+++ b/audioldm/clap/open_clip/__init__.py
@@ -0,0 +1,25 @@
+from .factory import (
+ list_models,
+ create_model,
+ create_model_and_transforms,
+ add_model_config,
+)
+from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
+from .model import (
+ CLAP,
+ CLAPTextCfg,
+ CLAPVisionCfg,
+ CLAPAudioCfp,
+ convert_weights_to_fp16,
+ trace_model,
+)
+from .openai import load_openai_model, list_openai_models
+from .pretrained import (
+ list_pretrained,
+ list_pretrained_tag_models,
+ list_pretrained_model_tags,
+ get_pretrained_url,
+ download_pretrained,
+)
+from .tokenizer import SimpleTokenizer, tokenize
+from .transform import image_transform
diff --git a/audioldm/clap/open_clip/bert.py b/audioldm/clap/open_clip/bert.py
new file mode 100755
index 0000000000000000000000000000000000000000..a83d96d2a77ed05198efc05837522bc88d2499cc
--- /dev/null
+++ b/audioldm/clap/open_clip/bert.py
@@ -0,0 +1,40 @@
+from transformers import BertTokenizer, BertModel
+
+tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+model = BertModel.from_pretrained("bert-base-uncased")
+text = "Replace me by any text you'd like."
+
+
+def bert_embeddings(text):
+ # text = "Replace me by any text you'd like."
+ encoded_input = tokenizer(text, return_tensors="pt")
+ output = model(**encoded_input)
+ return output
+
+
+from transformers import RobertaTokenizer, RobertaModel
+
+tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
+model = RobertaModel.from_pretrained("roberta-base")
+text = "Replace me by any text you'd like."
+
+
+def Roberta_embeddings(text):
+ # text = "Replace me by any text you'd like."
+ encoded_input = tokenizer(text, return_tensors="pt")
+ output = model(**encoded_input)
+ return output
+
+
+from transformers import BartTokenizer, BartModel
+
+tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
+model = BartModel.from_pretrained("facebook/bart-base")
+text = "Replace me by any text you'd like."
+
+
+def bart_embeddings(text):
+ # text = "Replace me by any text you'd like."
+ encoded_input = tokenizer(text, return_tensors="pt")
+ output = model(**encoded_input)
+ return output
diff --git a/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz b/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz
new file mode 100755
index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113
--- /dev/null
+++ b/audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/audioldm/clap/open_clip/factory.py b/audioldm/clap/open_clip/factory.py
new file mode 100755
index 0000000000000000000000000000000000000000..6e943a50abc767dfd8e04f1c27ae4830332717cd
--- /dev/null
+++ b/audioldm/clap/open_clip/factory.py
@@ -0,0 +1,279 @@
+import json
+import logging
+import os
+import pathlib
+import re
+from copy import deepcopy
+from pathlib import Path
+
+import torch
+
+from .model import CLAP, convert_weights_to_fp16
+from .openai import load_openai_model
+from .pretrained import get_pretrained_url, download_pretrained
+from .transform import image_transform
+
+_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
+_MODEL_CONFIGS = {}  # dictionary of model architecture configs (model_name: config)
+CACHE_DIR = os.getenv("AUDIOLDM_CACHE_DIR", "~/.cache/audioldm")
+
+
+
+def _natural_key(string_):
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
+
+
+def _rescan_model_configs():
+ global _MODEL_CONFIGS
+
+ config_ext = (".json",)
+ config_files = []
+ for config_path in _MODEL_CONFIG_PATHS:
+ if config_path.is_file() and config_path.suffix in config_ext:
+ config_files.append(config_path)
+ elif config_path.is_dir():
+ for ext in config_ext:
+ config_files.extend(config_path.glob(f"*{ext}"))
+
+ for cf in config_files:
+ if os.path.basename(cf)[0] == ".":
+ continue # Ignore hidden files
+
+ with open(cf, "r") as f:
+ model_cfg = json.load(f)
+ if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
+ _MODEL_CONFIGS[cf.stem] = model_cfg
+
+ _MODEL_CONFIGS = {
+ k: v
+ for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
+ }
+
+
+_rescan_model_configs() # initial populate of model config registry
+
+
+def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
+ if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
+ state_dict = checkpoint["state_dict"]
+ else:
+ state_dict = checkpoint
+ if skip_params:
+ if next(iter(state_dict.items()))[0].startswith("module"):
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
+ # for k in state_dict:
+ # if k.startswith('transformer'):
+ # v = state_dict.pop(k)
+ # state_dict['text_branch.' + k[12:]] = v
+ return state_dict
+
+
+def create_model(
+ amodel_name: str,
+ tmodel_name: str,
+ pretrained: str = "",
+ precision: str = "fp32",
+ device: torch.device = torch.device("cpu"),
+ jit: bool = False,
+ force_quick_gelu: bool = False,
+ openai_model_cache_dir: str = os.path.expanduser(f"{CACHE_DIR}/clip"),
+ skip_params=True,
+ pretrained_audio: str = "",
+ pretrained_text: str = "",
+ enable_fusion: bool = False,
+ fusion_type: str = "None"
+ # pretrained_image: bool = False,
+):
+ amodel_name = amodel_name.replace(
+ "/", "-"
+ ) # for callers using old naming with / in ViT names
+ pretrained_orig = pretrained
+ pretrained = pretrained.lower()
+ if pretrained == "openai":
+ if amodel_name in _MODEL_CONFIGS:
+ logging.info(f"Loading {amodel_name} model config.")
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
+ else:
+ logging.error(
+ f"Model config for {amodel_name} not found; available models {list_models()}."
+ )
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
+
+ logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
+ # Hard Code in model name
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
+ model = load_openai_model(
+ "ViT-B-16",
+ model_cfg,
+ device=device,
+ jit=jit,
+ cache_dir=openai_model_cache_dir,
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type,
+ )
+ # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
+ if precision == "amp" or precision == "fp32":
+ model = model.float()
+ else:
+ if amodel_name in _MODEL_CONFIGS:
+ logging.info(f"Loading {amodel_name} model config.")
+ model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
+ else:
+ logging.error(
+ f"Model config for {amodel_name} not found; available models {list_models()}."
+ )
+ raise RuntimeError(f"Model config for {amodel_name} not found.")
+
+ if force_quick_gelu:
+ # override for use of QuickGELU on non-OpenAI transformer models
+ model_cfg["quick_gelu"] = True
+
+ # if pretrained_image:
+ # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
+ # # pretrained weight loading for timm models set via vision_cfg
+ # model_cfg['vision_cfg']['timm_model_pretrained'] = True
+ # else:
+ # assert False, 'pretrained image towers currently only supported for timm models'
+ model_cfg["text_cfg"]["model_type"] = tmodel_name
+ model_cfg["enable_fusion"] = enable_fusion
+ model_cfg["fusion_type"] = fusion_type
+ model = CLAP(**model_cfg)
+
+ if pretrained:
+ checkpoint_path = ""
+ url = get_pretrained_url(amodel_name, pretrained)
+ if url:
+ checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
+ elif os.path.exists(pretrained_orig):
+ checkpoint_path = pretrained_orig
+ if checkpoint_path:
+ logging.info(
+ f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
+ )
+ ckpt = load_state_dict(checkpoint_path, skip_params=True)
+ model.load_state_dict(ckpt)
+ param_names = [n for n, p in model.named_parameters()]
+ # for n in param_names:
+ # print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
+ else:
+ logging.warning(
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
+ )
+ raise RuntimeError(
+ f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
+ )
+
+ if pretrained_audio:
+ if amodel_name.startswith("PANN"):
+ if "Cnn14_mAP" in pretrained_audio: # official checkpoint
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
+ audio_ckpt = audio_ckpt["model"]
+ keys = list(audio_ckpt.keys())
+ for key in keys:
+ if (
+ "spectrogram_extractor" not in key
+ and "logmel_extractor" not in key
+ ):
+ v = audio_ckpt.pop(key)
+ audio_ckpt["audio_branch." + key] = v
+ elif os.path.basename(pretrained_audio).startswith(
+ "PANN"
+ ): # checkpoint trained via HTSAT codebase
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
+ audio_ckpt = audio_ckpt["state_dict"]
+ keys = list(audio_ckpt.keys())
+ for key in keys:
+ if key.startswith("sed_model"):
+ v = audio_ckpt.pop(key)
+ audio_ckpt["audio_branch." + key[10:]] = v
+ elif os.path.basename(pretrained_audio).startswith(
+ "finetuned"
+ ): # checkpoint trained via linear probe codebase
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
+ else:
+ raise ValueError("Unknown audio checkpoint")
+ elif amodel_name.startswith("HTSAT"):
+ if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
+ audio_ckpt = audio_ckpt["state_dict"]
+ keys = list(audio_ckpt.keys())
+ for key in keys:
+ if key.startswith("sed_model") and (
+ "spectrogram_extractor" not in key
+ and "logmel_extractor" not in key
+ ):
+ v = audio_ckpt.pop(key)
+ audio_ckpt["audio_branch." + key[10:]] = v
+ elif os.path.basename(pretrained_audio).startswith(
+ "HTSAT"
+ ): # checkpoint trained via HTSAT codebase
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
+ audio_ckpt = audio_ckpt["state_dict"]
+ keys = list(audio_ckpt.keys())
+ for key in keys:
+ if key.startswith("sed_model"):
+ v = audio_ckpt.pop(key)
+ audio_ckpt["audio_branch." + key[10:]] = v
+ elif os.path.basename(pretrained_audio).startswith(
+ "finetuned"
+ ): # checkpoint trained via linear probe codebase
+ audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
+ else:
+ raise ValueError("Unknown audio checkpoint")
+ else:
+ raise f"this audio encoder pretrained checkpoint is not support"
+
+ model.load_state_dict(audio_ckpt, strict=False)
+ logging.info(
+ f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
+ )
+ param_names = [n for n, p in model.named_parameters()]
+ for n in param_names:
+ print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
+
+ model.to(device=device)
+ if precision == "fp16":
+ assert device.type != "cpu"
+ convert_weights_to_fp16(model)
+
+ if jit:
+ model = torch.jit.script(model)
+
+ return model, model_cfg
+
+
+def create_model_and_transforms(
+ model_name: str,
+ pretrained: str = "",
+ precision: str = "fp32",
+ device: torch.device = torch.device("cpu"),
+ jit: bool = False,
+ force_quick_gelu: bool = False,
+ # pretrained_image: bool = False,
+):
+ model = create_model(
+ model_name,
+ pretrained,
+ precision,
+ device,
+ jit,
+ force_quick_gelu=force_quick_gelu,
+ # pretrained_image=pretrained_image
+ )
+ preprocess_train = image_transform(model.visual.image_size, is_train=True)
+ preprocess_val = image_transform(model.visual.image_size, is_train=False)
+ return model, preprocess_train, preprocess_val
+
+
+def list_models():
+ """enumerate available model architectures based on config files"""
+ return list(_MODEL_CONFIGS.keys())
+
+
+def add_model_config(path):
+ """add model config path or file and update registry"""
+ if not isinstance(path, Path):
+ path = Path(path)
+ _MODEL_CONFIG_PATHS.append(path)
+ _rescan_model_configs()
diff --git a/audioldm/clap/open_clip/feature_fusion.py b/audioldm/clap/open_clip/feature_fusion.py
new file mode 100755
index 0000000000000000000000000000000000000000..dbe4e170e05894c12ebdc36ba1dc1de65e441b89
--- /dev/null
+++ b/audioldm/clap/open_clip/feature_fusion.py
@@ -0,0 +1,192 @@
+"""
+Feature Fusion for Variable-Length Data Processing
+AFF/iAFF are adapted from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
+According to the paper: Yimian Dai et al., Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
+"""
+
+import torch
+import torch.nn as nn
+
+
+class DAF(nn.Module):
+ """
+ Direct addition (DirectAddFuse)
+ """
+
+ def __init__(self):
+ super(DAF, self).__init__()
+
+ def forward(self, x, residual):
+ return x + residual
+
+
+class iAFF(nn.Module):
+ """
+ Iterative attentional multi-feature fusion (iAFF)
+ """
+
+ def __init__(self, channels=64, r=4, type="2D"):
+ super(iAFF, self).__init__()
+ inter_channels = int(channels // r)
+
+ if type == "1D":
+ # local attention
+ self.local_att = nn.Sequential(
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(channels),
+ )
+
+ # global attention
+ self.global_att = nn.Sequential(
+ nn.AdaptiveAvgPool1d(1),
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(channels),
+ )
+
+ # second local attention
+ self.local_att2 = nn.Sequential(
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(channels),
+ )
+ # second global attention
+ self.global_att2 = nn.Sequential(
+ nn.AdaptiveAvgPool1d(1),
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(channels),
+ )
+ elif type == "2D":
+ # local attention
+ self.local_att = nn.Sequential(
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+
+ # global attention
+ self.global_att = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1),
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+
+ # second local attention
+ self.local_att2 = nn.Sequential(
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+ # second global attention
+ self.global_att2 = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1),
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+ else:
+ raise f"the type is not supported"
+
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x, residual):
+ flag = False
+ xa = x + residual
+ if xa.size(0) == 1:
+ xa = torch.cat([xa, xa], dim=0)
+ flag = True
+ xl = self.local_att(xa)
+ xg = self.global_att(xa)
+ xlg = xl + xg
+ wei = self.sigmoid(xlg)
+ xi = x * wei + residual * (1 - wei)
+
+ xl2 = self.local_att2(xi)
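+ # note: self.global_att is reused here for the second round; self.global_att2 is defined in __init__ but never applied in this forward pass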
+ xg2 = self.global_att(xi)
+ xlg2 = xl2 + xg2
+ wei2 = self.sigmoid(xlg2)
+ xo = x * wei2 + residual * (1 - wei2)
+ if flag:
+ xo = xo[0].unsqueeze(0)
+ return xo
+
+
+class AFF(nn.Module):
+ """
+ Attentional multi-feature fusion (AFF)
+ """
+
+ def __init__(self, channels=64, r=4, type="2D"):
+ super(AFF, self).__init__()
+ inter_channels = int(channels // r)
+
+ if type == "1D":
+ self.local_att = nn.Sequential(
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(channels),
+ )
+ self.global_att = nn.Sequential(
+ nn.AdaptiveAvgPool1d(1),
+ nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm1d(channels),
+ )
+ elif type == "2D":
+ self.local_att = nn.Sequential(
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+ self.global_att = nn.Sequential(
+ nn.AdaptiveAvgPool2d(1),
+ nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(inter_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
+ nn.BatchNorm2d(channels),
+ )
+ else:
+ raise f"the type is not supported."
+
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x, residual):
+ flag = False
+ xa = x + residual
+ if xa.size(0) == 1:
+ xa = torch.cat([xa, xa], dim=0)
+ flag = True
+ xl = self.local_att(xa)
+ xg = self.global_att(xa)
+ xlg = xl + xg
+ wei = self.sigmoid(xlg)
+ xo = 2 * x * wei + 2 * residual * (1 - wei)
+ if flag:
+ xo = xo[0].unsqueeze(0)
+ return xo
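+
+
+if __name__ == "__main__":
+    # Illustrative sketch (not used by the package): fuse two feature maps of the
+    # same shape, as the HTS-AT patch embedding does for its global/local mel
+    # branches. channels=96 is an arbitrary example value.
+    aff = AFF(channels=96, r=4, type="2D")
+    x = torch.rand(2, 96, 32, 32)
+    residual = torch.rand(2, 96, 32, 32)
+    print(aff(x, residual).shape)  # torch.Size([2, 96, 32, 32])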
diff --git a/audioldm/clap/open_clip/htsat.py b/audioldm/clap/open_clip/htsat.py
new file mode 100755
index 0000000000000000000000000000000000000000..3b856c6a43df162116a941f1b5c76e93713b276a
--- /dev/null
+++ b/audioldm/clap/open_clip/htsat.py
@@ -0,0 +1,1308 @@
+# Ke Chen
+# knutchen@ucsd.edu
+# HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
+# Some layers designed on the model
+# below codes are based and referred from https://github.com/microsoft/Swin-Transformer
+# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from itertools import repeat
+import collections.abc
+import math
+import warnings
+
+from torch.nn.init import _calculate_fan_in_and_fan_out
+import torch.utils.checkpoint as checkpoint
+
+import random
+
+from torchlibrosa.stft import Spectrogram, LogmelFilterBank
+from torchlibrosa.augmentation import SpecAugmentation
+
+from itertools import repeat
+from .utils import do_mixup, interpolate
+
+from .feature_fusion import iAFF, AFF, DAF
+
+# from PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+ """
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (
+ x.ndim - 1
+ ) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
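+
+# Illustrative note: DropPath(0.1) zeroes the entire residual branch for roughly
+# 10% of the samples in a batch during training (rescaling kept samples by
+# 1/0.9) and acts as the identity at inference time.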
+
+
+class PatchEmbed(nn.Module):
+ """2D Image to Patch Embedding"""
+
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ norm_layer=None,
+ flatten=True,
+ patch_stride=16,
+ enable_fusion=False,
+ fusion_type="None",
+ ):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patch_stride = to_2tuple(patch_stride)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patch_stride = patch_stride
+ self.grid_size = (
+ img_size[0] // patch_stride[0],
+ img_size[1] // patch_stride[1],
+ )
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.flatten = flatten
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.enable_fusion = enable_fusion
+ self.fusion_type = fusion_type
+
+ padding = (
+ (patch_size[0] - patch_stride[0]) // 2,
+ (patch_size[1] - patch_stride[1]) // 2,
+ )
+
+ if (self.enable_fusion) and (self.fusion_type == "channel_map"):
+ self.proj = nn.Conv2d(
+ in_chans * 4,
+ embed_dim,
+ kernel_size=patch_size,
+ stride=patch_stride,
+ padding=padding,
+ )
+ else:
+ self.proj = nn.Conv2d(
+ in_chans,
+ embed_dim,
+ kernel_size=patch_size,
+ stride=patch_stride,
+ padding=padding,
+ )
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ if (self.enable_fusion) and (
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
+ ):
+ self.mel_conv2d = nn.Conv2d(
+ in_chans,
+ embed_dim,
+ kernel_size=(patch_size[0], patch_size[1] * 3),
+ stride=(patch_stride[0], patch_stride[1] * 3),
+ padding=padding,
+ )
+ if self.fusion_type == "daf_2d":
+ self.fusion_model = DAF()
+ elif self.fusion_type == "aff_2d":
+ self.fusion_model = AFF(channels=embed_dim, type="2D")
+ elif self.fusion_type == "iaff_2d":
+ self.fusion_model = iAFF(channels=embed_dim, type="2D")
+
+ def forward(self, x, longer_idx=None):
+ if (self.enable_fusion) and (
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
+ ):
+ global_x = x[:, 0:1, :, :]
+
+ # global processing
+ B, C, H, W = global_x.shape
+ assert (
+ H == self.img_size[0] and W == self.img_size[1]
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ global_x = self.proj(global_x)
+ TW = global_x.size(-1)
+ if len(longer_idx) > 0:
+ # local processing
+ local_x = x[longer_idx, 1:, :, :].contiguous()
+ B, C, H, W = local_x.shape
+ local_x = local_x.view(B * C, 1, H, W)
+ local_x = self.mel_conv2d(local_x)
+ local_x = local_x.view(
+ B, C, local_x.size(1), local_x.size(2), local_x.size(3)
+ )
+ local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3)
+ TB, TC, TH, _ = local_x.size()
+ if local_x.size(-1) < TW:
+ local_x = torch.cat(
+ [
+ local_x,
+ torch.zeros(
+ (TB, TC, TH, TW - local_x.size(-1)),
+ device=global_x.device,
+ ),
+ ],
+ dim=-1,
+ )
+ else:
+ local_x = local_x[:, :, :, :TW]
+
+ global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x)
+ x = global_x
+ else:
+ B, C, H, W = x.shape
+ assert (
+ H == self.img_size[0] and W == self.img_size[1]
+ ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x)
+
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
+ x = self.norm(x)
+ return x
+
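+# Shape sketch for PatchEmbed under the default HTSAT setting used later in this file
+# (img_size=256, patch_size=4, patch_stride=4, flatten=True, no fusion): grid_size is
+# (64, 64), num_patches is 4096, and forward() maps (B, in_chans, 256, 256) to
+# (B, 4096, embed_dim).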
+
+class Mlp(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2,
+ )
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.0))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
+ # type: (Tensor, float, float, float, float) -> Tensor
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+ if mode == "fan_in":
+ denom = fan_in
+ elif mode == "fan_out":
+ denom = fan_out
+ elif mode == "fan_avg":
+ denom = (fan_in + fan_out) / 2
+
+ variance = scale / denom
+
+ if distribution == "truncated_normal":
+ # constant is stddev of standard normal truncated to (-2, 2)
+ trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
+ elif distribution == "normal":
+ tensor.normal_(std=math.sqrt(variance))
+ elif distribution == "uniform":
+ bound = math.sqrt(3 * variance)
+ tensor.uniform_(-bound, bound)
+ else:
+ raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
+
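+# Numerical sketch for lecun_normal_: on an nn.Linear(128, 256).weight, fan_in is 128,
+# the target variance is 1.0 / 128, and the truncated-normal std becomes
+# sqrt(1 / 128) / 0.8796... ~= 0.10 (the constant compensates for truncation to (-2, 2)).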
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = (
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ )
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(
+ B, H // window_size, W // window_size, window_size, window_size, -1
+ )
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
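+# Round-trip shape sketch: for x of shape (2, 64, 64, 96) and window_size=8,
+# window_partition(x, 8) returns (2 * 8 * 8, 8, 8, 96) = (128, 8, 8, 96), and
+# window_reverse(windows, 8, 64, 64) recovers the original (2, 64, 64, 96).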
+
+class WindowAttention(nn.Module):
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both shifted and non-shifted windows.
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(
+ self,
+ dim,
+ window_size,
+ num_heads,
+ qkv_bias=True,
+ qk_scale=None,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ ):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = (
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ ) # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(
+ 1, 2, 0
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = (
+ qkv[0],
+ qkv[1],
+ qkv[2],
+ ) # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)
+ ].view(
+ self.window_size[0] * self.window_size[1],
+ self.window_size[0] * self.window_size[1],
+ -1,
+ ) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
+ 1
+ ).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x, attn
+
+ def extra_repr(self):
+ return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
+
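+# Shape sketch for WindowAttention with dim=96, window_size=(8, 8), num_heads=4:
+# the relative position bias table has shape (15 * 15, 4) = (225, 4); forward() takes
+# x of shape (num_windows * B, 64, 96) and returns features of the same shape plus
+# attention weights of shape (num_windows * B, 4, 64, 64).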
+
+# The model is built on Swin Transformer blocks, so pretrained Swin Transformer weights can be reused
+class SwinTransformerBlock(nn.Module):
+ r"""Swin Transformer Block.
+ Args:
+ dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(
+ self,
+ dim,
+ input_resolution,
+ num_heads,
+ window_size=7,
+ shift_size=0,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ norm_before_mlp="ln",
+ ):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ self.norm_before_mlp = norm_before_mlp
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+ assert (
+ 0 <= self.shift_size < self.window_size
+        ), "shift_size must be in the range [0, window_size)"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim,
+ window_size=to_2tuple(self.window_size),
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ if self.norm_before_mlp == "ln":
+ self.norm2 = nn.LayerNorm(dim)
+ elif self.norm_before_mlp == "bn":
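+            # NOTE: this lambda builds a brand-new BatchNorm1d on every call, so its
+            # affine parameters and running statistics are never actually trained.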
+ self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(
+ 1, 2
+ )
+ else:
+ raise NotImplementedError
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ H, W = self.input_resolution
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ w_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(
+ img_mask, self.window_size
+ ) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(
+ attn_mask != 0, float(-100.0)
+ ).masked_fill(attn_mask == 0, float(0.0))
+ else:
+ attn_mask = None
+
+ self.register_buffer("attn_mask", attn_mask)
+
+ def forward(self, x):
+ # pdb.set_trace()
+ H, W = self.input_resolution
+ # print("H: ", H)
+ # print("W: ", W)
+ # pdb.set_trace()
+ B, L, C = x.shape
+ # assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
+ )
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(
+ shifted_x, self.window_size
+ ) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(
+ -1, self.window_size * self.window_size, C
+ ) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows, attn = self.attn(
+ x_windows, mask=self.attn_mask
+ ) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
+ )
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x, attn
+
+ def extra_repr(self):
+ return (
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+ )
+
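+# Mask sketch for the shifted case: with input_resolution=(64, 64), window_size=8 and
+# shift_size=4, img_mask labels 3 x 3 = 9 regions, window_partition yields nW = 64
+# windows, and attn_mask has shape (64, 64, 64) holding -100.0 wherever two tokens of a
+# window come from different regions, which suppresses their attention after the softmax.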
+
+class PatchMerging(nn.Module):
+ r"""Patch Merging Layer.
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x):
+ """
+ x: B, H*W, C
+ """
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
+
+ x = x.view(B, H, W, C)
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+ def extra_repr(self):
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
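+# Shape sketch for PatchMerging: with input_resolution=(64, 64) and dim=96 it maps
+# (B, 64 * 64, 96) to (B, 32 * 32, 192), halving the spatial resolution and doubling
+# the channel dimension.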
+
+class BasicLayer(nn.Module):
+ """A basic Swin Transformer layer for one stage.
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(
+ self,
+ dim,
+ input_resolution,
+ depth,
+ num_heads,
+ window_size,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False,
+ norm_before_mlp="ln",
+ ):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList(
+ [
+ SwinTransformerBlock(
+ dim=dim,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i]
+ if isinstance(drop_path, list)
+ else drop_path,
+ norm_layer=norm_layer,
+ norm_before_mlp=norm_before_mlp,
+ )
+ for i in range(depth)
+ ]
+ )
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(
+ input_resolution, dim=dim, norm_layer=norm_layer
+ )
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ attns = []
+ for blk in self.blocks:
+ if self.use_checkpoint:
+                # SwinTransformerBlock returns (x, attn); unpack the checkpointed call as well
+                x, attn = checkpoint.checkpoint(blk, x)
+ else:
+ x, attn = blk(x)
+ if not self.training:
+ attns.append(attn.unsqueeze(0))
+ if self.downsample is not None:
+ x = self.downsample(x)
+ if not self.training:
+ attn = torch.cat(attns, dim=0)
+ attn = torch.mean(attn, dim=0)
+ return x, attn
+
+ def extra_repr(self):
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+
+# The Core of HTSAT
+class HTSAT_Swin_Transformer(nn.Module):
+ r"""HTSAT based on the Swin Transformer
+ Args:
+ spec_size (int | tuple(int)): Input Spectrogram size. Default 256
+ patch_size (int | tuple(int)): Patch size. Default: 4
+        patch_stride (int | tuple(int)): Patch stride along the frequency and time axes. Default: 4
+ in_chans (int): Number of input image channels. Default: 1 (mono)
+ num_classes (int): Number of classes for classification head. Default: 527
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 8
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ config (module): The configuration Module from config.py
+ """
+
+ def __init__(
+ self,
+ spec_size=256,
+ patch_size=4,
+ patch_stride=(4, 4),
+ in_chans=1,
+ num_classes=527,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=8,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ ape=False,
+ patch_norm=True,
+ use_checkpoint=False,
+ norm_before_mlp="ln",
+ config=None,
+ enable_fusion=False,
+ fusion_type="None",
+ **kwargs,
+ ):
+ super(HTSAT_Swin_Transformer, self).__init__()
+
+ self.config = config
+ self.spec_size = spec_size
+ self.patch_stride = patch_stride
+ self.patch_size = patch_size
+ self.window_size = window_size
+ self.embed_dim = embed_dim
+ self.depths = depths
+ self.ape = ape
+ self.in_chans = in_chans
+ self.num_classes = num_classes
+ self.num_heads = num_heads
+ self.num_layers = len(self.depths)
+ self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
+
+ self.drop_rate = drop_rate
+ self.attn_drop_rate = attn_drop_rate
+ self.drop_path_rate = drop_path_rate
+
+ self.qkv_bias = qkv_bias
+ self.qk_scale = None
+
+ self.patch_norm = patch_norm
+ self.norm_layer = norm_layer if self.patch_norm else None
+ self.norm_before_mlp = norm_before_mlp
+ self.mlp_ratio = mlp_ratio
+
+ self.use_checkpoint = use_checkpoint
+
+ self.enable_fusion = enable_fusion
+ self.fusion_type = fusion_type
+
+ # process mel-spec ; used only once
+ self.freq_ratio = self.spec_size // self.config.mel_bins
+ window = "hann"
+ center = True
+ pad_mode = "reflect"
+ ref = 1.0
+ amin = 1e-10
+ top_db = None
+ self.interpolate_ratio = 32 # Downsampled ratio
+ # Spectrogram extractor
+ self.spectrogram_extractor = Spectrogram(
+ n_fft=config.window_size,
+ hop_length=config.hop_size,
+ win_length=config.window_size,
+ window=window,
+ center=center,
+ pad_mode=pad_mode,
+ freeze_parameters=True,
+ )
+ # Logmel feature extractor
+ self.logmel_extractor = LogmelFilterBank(
+ sr=config.sample_rate,
+ n_fft=config.window_size,
+ n_mels=config.mel_bins,
+ fmin=config.fmin,
+ fmax=config.fmax,
+ ref=ref,
+ amin=amin,
+ top_db=top_db,
+ freeze_parameters=True,
+ )
+ # Spec augmenter
+ self.spec_augmenter = SpecAugmentation(
+ time_drop_width=64,
+ time_stripes_num=2,
+ freq_drop_width=8,
+ freq_stripes_num=2,
+ ) # 2 2
+ self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
+
+        # split the spectrogram into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=self.spec_size,
+ patch_size=self.patch_size,
+ in_chans=self.in_chans,
+ embed_dim=self.embed_dim,
+ norm_layer=self.norm_layer,
+ patch_stride=patch_stride,
+ enable_fusion=self.enable_fusion,
+ fusion_type=self.fusion_type,
+ )
+
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.grid_size
+ self.patches_resolution = patches_resolution
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches, self.embed_dim)
+ )
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
+
+ self.pos_drop = nn.Dropout(p=self.drop_rate)
+
+ # stochastic depth
+ dpr = [
+ x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))
+ ] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(
+ dim=int(self.embed_dim * 2**i_layer),
+ input_resolution=(
+ patches_resolution[0] // (2**i_layer),
+ patches_resolution[1] // (2**i_layer),
+ ),
+ depth=self.depths[i_layer],
+ num_heads=self.num_heads[i_layer],
+ window_size=self.window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=self.qkv_bias,
+ qk_scale=self.qk_scale,
+ drop=self.drop_rate,
+ attn_drop=self.attn_drop_rate,
+ drop_path=dpr[
+ sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1])
+ ],
+ norm_layer=self.norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint,
+ norm_before_mlp=self.norm_before_mlp,
+ )
+ self.layers.append(layer)
+
+ self.norm = self.norm_layer(self.num_features)
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
+ self.maxpool = nn.AdaptiveMaxPool1d(1)
+
+ SF = (
+ self.spec_size
+ // (2 ** (len(self.depths) - 1))
+ // self.patch_stride[0]
+ // self.freq_ratio
+ )
+ self.tscam_conv = nn.Conv2d(
+ in_channels=self.num_features,
+ out_channels=self.num_classes,
+ kernel_size=(SF, 3),
+ padding=(0, 1),
+ )
+ self.head = nn.Linear(num_classes, num_classes)
+
+ if (self.enable_fusion) and (
+ self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
+ ):
+ self.mel_conv1d = nn.Sequential(
+ nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
+ nn.BatchNorm1d(64),
+ )
+ if self.fusion_type == "daf_1d":
+ self.fusion_model = DAF()
+ elif self.fusion_type == "aff_1d":
+ self.fusion_model = AFF(channels=64, type="1D")
+ elif self.fusion_type == "iaff_1d":
+ self.fusion_model = iAFF(channels=64, type="1D")
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {"absolute_pos_embed"}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {"relative_position_bias_table"}
+
+ def forward_features(self, x, longer_idx=None):
+ # A deprecated optimization for using a hierarchical output from different blocks
+
+ frames_num = x.shape[2]
+ x = self.patch_embed(x, longer_idx=longer_idx)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+ for i, layer in enumerate(self.layers):
+ x, attn = layer(x)
+ # for x
+ x = self.norm(x)
+ B, N, C = x.shape
+ SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
+ ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
+ x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST)
+ B, C, F, T = x.shape
+ # group 2D CNN
+ c_freq_bin = F // self.freq_ratio
+ x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
+ x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1)
+ # get latent_output
+ fine_grained_latent_output = torch.mean(x, dim=2)
+ fine_grained_latent_output = interpolate(
+ fine_grained_latent_output.permute(0, 2, 1).contiguous(),
+ 8 * self.patch_stride[1],
+ )
+
+ latent_output = self.avgpool(torch.flatten(x, 2))
+ latent_output = torch.flatten(latent_output, 1)
+
+ # display the attention map, if needed
+
+ x = self.tscam_conv(x)
+ x = torch.flatten(x, 2) # B, C, T
+
+ fpx = interpolate(
+ torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1]
+ )
+
+ x = self.avgpool(x)
+ x = torch.flatten(x, 1)
+
+ output_dict = {
+ "framewise_output": fpx, # already sigmoided
+ "clipwise_output": torch.sigmoid(x),
+ "fine_grained_embedding": fine_grained_latent_output,
+ "embedding": latent_output,
+ }
+
+ return output_dict
+
+ def crop_wav(self, x, crop_size, spe_pos=None):
+ time_steps = x.shape[2]
+ tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
+ for i in range(len(x)):
+ if spe_pos is None:
+ crop_pos = random.randint(0, time_steps - crop_size - 1)
+ else:
+ crop_pos = spe_pos
+ tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :]
+ return tx
+
+    # Reshape the waveform spectrogram to an image-like size so that the pretrained Swin Transformer weights can be used
+ def reshape_wav2img(self, x):
+ B, C, T, F = x.shape
+ target_T = int(self.spec_size * self.freq_ratio)
+ target_F = self.spec_size // self.freq_ratio
+ assert (
+ T <= target_T and F <= target_F
+        ), "the wav size should be less than or equal to the swin input size"
+ # to avoid bicubic zero error
+ if T < target_T:
+ x = nn.functional.interpolate(
+ x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
+ )
+ if F < target_F:
+ x = nn.functional.interpolate(
+ x, (x.shape[2], target_F), mode="bicubic", align_corners=True
+ )
+ x = x.permute(0, 1, 3, 2).contiguous()
+ x = x.reshape(
+ x.shape[0],
+ x.shape[1],
+ x.shape[2],
+ self.freq_ratio,
+ x.shape[3] // self.freq_ratio,
+ )
+ # print(x.shape)
+ x = x.permute(0, 1, 3, 2, 4).contiguous()
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
+ return x
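+
+    # Shape sketch for reshape_wav2img under the default config (spec_size=256,
+    # mel_bins=64, hence freq_ratio=4): an input of shape (B, 1, T, 64) with T <= 1024
+    # is interpolated to (B, 1, 1024, 64) and folded into (B, 1, 256, 256) by stacking
+    # four 64 x 256 frequency-time chunks into one square "image" for the Swin backbone.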
+
+    # Repeat the waveform spectrogram to an image-like size so that the pretrained Swin Transformer weights can be used
+ def repeat_wat2img(self, x, cur_pos):
+ B, C, T, F = x.shape
+ target_T = int(self.spec_size * self.freq_ratio)
+ target_F = self.spec_size // self.freq_ratio
+ assert (
+ T <= target_T and F <= target_F
+        ), "the wav size should be less than or equal to the swin input size"
+ # to avoid bicubic zero error
+ if T < target_T:
+ x = nn.functional.interpolate(
+ x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
+ )
+ if F < target_F:
+ x = nn.functional.interpolate(
+ x, (x.shape[2], target_F), mode="bicubic", align_corners=True
+ )
+ x = x.permute(0, 1, 3, 2).contiguous() # B C F T
+ x = x[:, :, :, cur_pos : cur_pos + self.spec_size]
+ x = x.repeat(repeats=(1, 1, 4, 1))
+ return x
+
+ def forward(
+ self, x: torch.Tensor, mixup_lambda=None, infer_mode=False, device=None
+ ): # out_feat_keys: List[str] = None):
+
+ if self.enable_fusion and x["longer"].sum() == 0:
+ # if no audio is longer than 10s, then randomly select one audio to be longer
+ x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True
+
+ if not self.enable_fusion:
+ x = x["waveform"].to(device=device, non_blocking=True)
+ x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+ if self.training:
+ x = self.spec_augmenter(x)
+
+ if self.training and mixup_lambda is not None:
+ x = do_mixup(x, mixup_lambda)
+
+ x = self.reshape_wav2img(x)
+ output_dict = self.forward_features(x)
+ else:
+ longer_list = x["longer"].to(device=device, non_blocking=True)
+ x = x["mel_fusion"].to(device=device, non_blocking=True)
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+ longer_list_idx = torch.where(longer_list)[0]
+ if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
+ new_x = x[:, 0:1, :, :].clone().contiguous()
+ if len(longer_list_idx) > 0:
+ # local processing
+ fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
+ FB, FC, FT, FF = fusion_x_local.size()
+ fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
+ fusion_x_local = torch.permute(
+ fusion_x_local, (0, 2, 1)
+ ).contiguous()
+ fusion_x_local = self.mel_conv1d(fusion_x_local)
+ fusion_x_local = fusion_x_local.view(
+ FB, FC, FF, fusion_x_local.size(-1)
+ )
+ fusion_x_local = (
+ torch.permute(fusion_x_local, (0, 2, 1, 3))
+ .contiguous()
+ .flatten(2)
+ )
+ if fusion_x_local.size(-1) < FT:
+ fusion_x_local = torch.cat(
+ [
+ fusion_x_local,
+ torch.zeros(
+ (FB, FF, FT - fusion_x_local.size(-1)),
+ device=device,
+ ),
+ ],
+ dim=-1,
+ )
+ else:
+ fusion_x_local = fusion_x_local[:, :, :FT]
+ # 1D fusion
+ new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
+ new_x[longer_list_idx] = self.fusion_model(
+ new_x[longer_list_idx], fusion_x_local
+ )
+ x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
+ else:
+ x = new_x
+
+ elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
+ x = x # no change
+
+ if self.training:
+ x = self.spec_augmenter(x)
+ if self.training and mixup_lambda is not None:
+ x = do_mixup(x, mixup_lambda)
+
+ x = self.reshape_wav2img(x)
+ output_dict = self.forward_features(x, longer_idx=longer_list_idx)
+
+ # if infer_mode:
+ # # in infer mode. we need to handle different length audio input
+ # frame_num = x.shape[2]
+ # target_T = int(self.spec_size * self.freq_ratio)
+ # repeat_ratio = math.floor(target_T / frame_num)
+ # x = x.repeat(repeats=(1,1,repeat_ratio,1))
+ # x = self.reshape_wav2img(x)
+ # output_dict = self.forward_features(x)
+ # else:
+ # if x.shape[2] > self.freq_ratio * self.spec_size:
+ # if self.training:
+ # x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
+ # x = self.reshape_wav2img(x)
+ # output_dict = self.forward_features(x)
+ # else:
+ # # Change: Hard code here
+ # overlap_size = (x.shape[2] - 1) // 4
+ # output_dicts = []
+ # crop_size = (x.shape[2] - 1) // 2
+ # for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
+ # tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
+ # tx = self.reshape_wav2img(tx)
+ # output_dicts.append(self.forward_features(tx))
+ # clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
+ # framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
+ # for d in output_dicts:
+ # clipwise_output += d["clipwise_output"]
+ # framewise_output += d["framewise_output"]
+ # clipwise_output = clipwise_output / len(output_dicts)
+ # framewise_output = framewise_output / len(output_dicts)
+ # output_dict = {
+ # 'framewise_output': framewise_output,
+ # 'clipwise_output': clipwise_output
+ # }
+ # else: # this part is typically used, and most easy one
+ # x = self.reshape_wav2img(x)
+ # output_dict = self.forward_features(x)
+ # x = self.head(x)
+
+        # The data is already processed in the dataloader, so here we only consider the case input_T < fixed_T
+
+ return output_dict
+
+
+def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"):
+ try:
+
+ assert audio_cfg.model_name in [
+ "tiny",
+ "base",
+ "large",
+ ], "model name for HTS-AT is wrong!"
+ if audio_cfg.model_name == "tiny":
+ model = HTSAT_Swin_Transformer(
+ spec_size=256,
+ patch_size=4,
+ patch_stride=(4, 4),
+ num_classes=audio_cfg.class_num,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=8,
+ config=audio_cfg,
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type,
+ )
+ elif audio_cfg.model_name == "base":
+ model = HTSAT_Swin_Transformer(
+ spec_size=256,
+ patch_size=4,
+ patch_stride=(4, 4),
+ num_classes=audio_cfg.class_num,
+ embed_dim=128,
+ depths=[2, 2, 12, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=8,
+ config=audio_cfg,
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type,
+ )
+ elif audio_cfg.model_name == "large":
+ model = HTSAT_Swin_Transformer(
+ spec_size=256,
+ patch_size=4,
+ patch_stride=(4, 4),
+ num_classes=audio_cfg.class_num,
+ embed_dim=256,
+ depths=[2, 2, 12, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=8,
+ config=audio_cfg,
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type,
+ )
+
+ return model
+    except Exception as e:
+        raise RuntimeError(
+            f"Model '{audio_cfg.model_name}' for HTS-AT not found, or the audio cfg is missing required parameters."
+        ) from e
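+
+
+# Usage sketch for create_htsat_model (the config attributes below are assumptions about
+# what the surrounding code expects, not a documented API):
+#     from types import SimpleNamespace
+#     audio_cfg = SimpleNamespace(model_name="tiny", class_num=527, sample_rate=48000,
+#                                 window_size=1024, hop_size=480, mel_bins=64,
+#                                 fmin=50, fmax=14000)
+#     model = create_htsat_model(audio_cfg)  # HTSAT "tiny", no feature fusion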
diff --git a/audioldm/clap/open_clip/linear_probe.py b/audioldm/clap/open_clip/linear_probe.py
new file mode 100755
index 0000000000000000000000000000000000000000..9d7e23b6b67a53e16d050d675a99d01d7d04d581
--- /dev/null
+++ b/audioldm/clap/open_clip/linear_probe.py
@@ -0,0 +1,66 @@
+import numpy as np
+import torch.nn.functional as F
+from torch import nn
+from .model import MLPLayers
+
+
+class LinearProbe(nn.Module):
+ def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
+ """
+ Args:
+ model: nn.Module
+ mlp: bool, if True, then use the MLP layer as the linear probe module
+            freeze: bool, if True, then freeze all the CLAP model's layers when training the linear probe
+ in_ch: int, the output channel from CLAP model
+ out_ch: int, the output channel from linear probe (class_num)
+ act: torch.nn.functional, the activation function before the loss function
+ """
+ super().__init__()
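+        # NOTE: in_ch is overridden below; the probe assumes a 512-d embedding from the
+        # CLAP audio projection regardless of the value passed in.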
+ in_ch = 512
+ self.clap_model = model
+ self.clap_model.text_branch = None # to save memory
+ self.freeze = freeze
+ if mlp:
+ self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
+ else:
+ self.lp_layer = nn.Linear(in_ch, out_ch)
+
+ if self.freeze:
+ for param in self.clap_model.parameters():
+ param.requires_grad = False
+
+ if act == "None":
+ self.act = None
+ elif act == "relu":
+ self.act = nn.ReLU()
+ elif act == "elu":
+ self.act = nn.ELU()
+ elif act == "prelu":
+ self.act = nn.PReLU(num_parameters=in_ch)
+ elif act == "softmax":
+ self.act = nn.Softmax(dim=-1)
+ elif act == "sigmoid":
+ self.act = nn.Sigmoid()
+
+ def forward(self, x, mix_lambda=None, device=None):
+ """
+ Args:
+ x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list
+ mix_lambda: torch.tensor [batch], the mixup lambda
+ Returns:
+ class_prob: torch.tensor [batch, class_num]
+
+ """
+        # keep the frozen CLAP model in eval mode so that BatchNorm statistics and gradients are not updated
+ if self.freeze:
+ self.clap_model.eval()
+
+ x = self.clap_model.audio_projection(
+ self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)[
+ "embedding"
+ ]
+ )
+ out = self.lp_layer(x)
+ if self.act is not None:
+ out = self.act(out)
+ return out
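+
+
+# Usage sketch (hypothetical class count and backbone; not part of the original file):
+#     probe = LinearProbe(clap_model, mlp=True, freeze=True, in_ch=512, out_ch=527, act="sigmoid")
+#     class_prob = probe(audio_batch, device="cuda")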
diff --git a/audioldm/clap/open_clip/loss.py b/audioldm/clap/open_clip/loss.py
new file mode 100755
index 0000000000000000000000000000000000000000..cc66298a14997da4aa2efc71e37c0a6bcda53fd1
--- /dev/null
+++ b/audioldm/clap/open_clip/loss.py
@@ -0,0 +1,398 @@
+from multiprocessing.sharedctypes import Value
+import torch
+import torch.distributed.nn
+from torch import distributed as dist, nn as nn
+from torch.nn import functional as F
+import numpy as np
+from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
+
+try:
+ import horovod.torch as hvd
+except ImportError:
+ hvd = None
+
+
+def gather_features(
+ audio_features,
+ text_features,
+ audio_features_mlp=None,
+ text_features_mlp=None,
+ local_loss=False,
+ gather_with_grad=False,
+ rank=0,
+ world_size=1,
+ use_horovod=False,
+ mlp_loss=False,
+):
+ if use_horovod:
+ assert hvd is not None, "Please install horovod"
+ if gather_with_grad:
+ all_audio_features = hvd.allgather(audio_features)
+ all_text_features = hvd.allgather(text_features)
+ if mlp_loss:
+ all_audio_features_mlp = hvd.allgather(audio_features_mlp)
+ all_text_features_mlp = hvd.allgather(text_features_mlp)
+ else:
+ with torch.no_grad():
+ all_audio_features = hvd.allgather(audio_features)
+ all_text_features = hvd.allgather(text_features)
+ if mlp_loss:
+ all_audio_features_mlp = hvd.allgather(audio_features_mlp)
+ all_text_features_mlp = hvd.allgather(text_features_mlp)
+ if not local_loss:
+ # ensure grads for local rank when all_* features don't have a gradient
+ gathered_audio_features = list(
+ all_audio_features.chunk(world_size, dim=0)
+ )
+ gathered_text_features = list(
+ all_text_features.chunk(world_size, dim=0)
+ )
+ gathered_audio_features[rank] = audio_features
+ gathered_text_features[rank] = text_features
+ all_audio_features = torch.cat(gathered_audio_features, dim=0)
+ all_text_features = torch.cat(gathered_text_features, dim=0)
+ if mlp_loss:
+ gathered_audio_features_mlp = list(
+ all_audio_features_mlp.chunk(world_size, dim=0)
+ )
+ gathered_text_features_mlp = list(
+ all_text_features_mlp.chunk(world_size, dim=0)
+ )
+ gathered_audio_features_mlp[rank] = audio_features_mlp
+ gathered_text_features_mlp[rank] = text_features_mlp
+ all_audio_features_mlp = torch.cat(
+ gathered_audio_features_mlp, dim=0
+ )
+ all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
+ else:
+ # We gather tensors from all gpus
+ if gather_with_grad:
+ all_audio_features = torch.cat(
+ torch.distributed.nn.all_gather(audio_features), dim=0
+ )
+ all_text_features = torch.cat(
+ torch.distributed.nn.all_gather(text_features), dim=0
+ )
+ if mlp_loss:
+ all_audio_features_mlp = torch.cat(
+ torch.distributed.nn.all_gather(audio_features_mlp), dim=0
+ )
+ all_text_features_mlp = torch.cat(
+ torch.distributed.nn.all_gather(text_features_mlp), dim=0
+ )
+ else:
+ gathered_audio_features = [
+ torch.zeros_like(audio_features) for _ in range(world_size)
+ ]
+ gathered_text_features = [
+ torch.zeros_like(text_features) for _ in range(world_size)
+ ]
+ dist.all_gather(gathered_audio_features, audio_features)
+ dist.all_gather(gathered_text_features, text_features)
+ if mlp_loss:
+ gathered_audio_features_mlp = [
+ torch.zeros_like(audio_features_mlp) for _ in range(world_size)
+ ]
+ gathered_text_features_mlp = [
+ torch.zeros_like(text_features_mlp) for _ in range(world_size)
+ ]
+ dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
+ dist.all_gather(gathered_text_features_mlp, text_features_mlp)
+ if not local_loss:
+ # ensure grads for local rank when all_* features don't have a gradient
+ gathered_audio_features[rank] = audio_features
+ gathered_text_features[rank] = text_features
+ if mlp_loss:
+ gathered_audio_features_mlp[rank] = audio_features_mlp
+ gathered_text_features_mlp[rank] = text_features_mlp
+
+ all_audio_features = torch.cat(gathered_audio_features, dim=0)
+ all_text_features = torch.cat(gathered_text_features, dim=0)
+ if mlp_loss:
+ all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
+ all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
+ if mlp_loss:
+ return (
+ all_audio_features,
+ all_text_features,
+ all_audio_features_mlp,
+ all_text_features_mlp,
+ )
+ else:
+ return all_audio_features, all_text_features
+
+
+class ClipLoss(nn.Module):
+ def __init__(
+ self,
+ local_loss=False,
+ gather_with_grad=False,
+ cache_labels=False,
+ rank=0,
+ world_size=1,
+ use_horovod=False,
+ mlp_loss=False,
+ weight_loss_kappa=0,
+ ):
+ super().__init__()
+ self.local_loss = local_loss
+ self.gather_with_grad = gather_with_grad
+ self.cache_labels = cache_labels
+ self.rank = rank
+ self.world_size = world_size
+ self.use_horovod = use_horovod
+ self.mlp_loss = mlp_loss
+ self.weighted_loss = bool(weight_loss_kappa != 0)
+ self.weight_loss_kappa = weight_loss_kappa
+ # cache state
+ self.prev_num_logits = 0
+ self.labels = {}
+
+ def forward(
+ self,
+ audio_features,
+ text_features,
+ logit_scale_a,
+ logit_scale_t=None,
+ audio_features_mlp=None,
+ text_features_mlp=None,
+ ):
+ device = audio_features.device
+ if self.mlp_loss:
+ if self.world_size > 1:
+ (
+ all_audio_features,
+ all_text_features,
+ all_audio_features_mlp,
+ all_text_features_mlp,
+ ) = gather_features(
+ audio_features=audio_features,
+ text_features=text_features,
+ audio_features_mlp=audio_features_mlp,
+ text_features_mlp=text_features_mlp,
+ local_loss=self.local_loss,
+ gather_with_grad=self.gather_with_grad,
+ rank=self.rank,
+ world_size=self.world_size,
+ use_horovod=self.use_horovod,
+ mlp_loss=self.mlp_loss,
+ )
+ if self.local_loss:
+ a_logits_per_audio = (
+ logit_scale_a * audio_features @ all_text_features_mlp.T
+ )
+ a_logits_per_text = (
+ logit_scale_a * text_features_mlp @ all_audio_features.T
+ )
+ t_logits_per_audio = (
+ logit_scale_t * audio_features_mlp @ all_text_features.T
+ )
+ t_logits_per_text = (
+ logit_scale_t * text_features @ all_audio_features_mlp.T
+ )
+ else:
+ a_logits_per_audio = (
+ logit_scale_a * all_audio_features @ all_text_features_mlp.T
+ )
+ a_logits_per_text = a_logits_per_audio.T
+ t_logits_per_audio = (
+ logit_scale_t * all_audio_features_mlp @ all_text_features.T
+ )
+ t_logits_per_text = t_logits_per_audio.T
+ else:
+ a_logits_per_audio = (
+ logit_scale_a * audio_features @ text_features_mlp.T
+ )
+ a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
+ t_logits_per_audio = (
+ logit_scale_t * audio_features_mlp @ text_features.T
+ )
+ t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T
+
+            # calculate the ground-truth labels and cache them if enabled
+ num_logits = a_logits_per_audio.shape[0]
+ if self.prev_num_logits != num_logits or device not in self.labels:
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
+ if self.world_size > 1 and self.local_loss:
+ labels = labels + num_logits * self.rank
+ if self.cache_labels:
+ self.labels[device] = labels
+ self.prev_num_logits = num_logits
+ else:
+ labels = self.labels[device]
+
+ if not self.weighted_loss:
+ total_loss = (
+ F.cross_entropy(a_logits_per_audio, labels)
+ + F.cross_entropy(a_logits_per_text, labels)
+ + F.cross_entropy(t_logits_per_audio, labels)
+ + F.cross_entropy(t_logits_per_text, labels)
+ ) / 4
+ else:
+ audio_weight = (audio_features @ audio_features.T).detach()
+ audio_weight = (
+ torch.exp(
+ torch.sum(audio_weight, axis=1)
+ / (self.weight_loss_kappa * len(audio_weight))
+ )
+ ).detach()
+ text_weight = (text_features @ text_features.T).detach()
+ text_weight = (
+ torch.exp(
+ torch.sum(text_weight, axis=1)
+ / (self.weight_loss_kappa * len(text_features))
+ )
+ ).detach()
+ total_loss = (
+ F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight)
+ + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight)
+ + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight)
+ + F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
+ ) / 4
+ else:
+ if self.world_size > 1:
+ all_audio_features, all_text_features = gather_features(
+ audio_features=audio_features,
+ text_features=text_features,
+ local_loss=self.local_loss,
+ gather_with_grad=self.gather_with_grad,
+ rank=self.rank,
+ world_size=self.world_size,
+ use_horovod=self.use_horovod,
+ mlp_loss=self.mlp_loss,
+ )
+
+ if self.local_loss:
+ logits_per_audio = (
+ logit_scale_a * audio_features @ all_text_features.T
+ )
+ logits_per_text = (
+ logit_scale_a * text_features @ all_audio_features.T
+ )
+ else:
+ logits_per_audio = (
+ logit_scale_a * all_audio_features @ all_text_features.T
+ )
+ logits_per_text = logits_per_audio.T
+ else:
+ logits_per_audio = logit_scale_a * audio_features @ text_features.T
+ logits_per_text = logit_scale_a * text_features @ audio_features.T
+
+            # calculate the ground-truth labels and cache them if enabled
+ num_logits = logits_per_audio.shape[0]
+ if self.prev_num_logits != num_logits or device not in self.labels:
+ labels = torch.arange(num_logits, device=device, dtype=torch.long)
+ if self.world_size > 1 and self.local_loss:
+ labels = labels + num_logits * self.rank
+ if self.cache_labels:
+ self.labels[device] = labels
+ self.prev_num_logits = num_logits
+ else:
+ labels = self.labels[device]
+ if not self.weighted_loss:
+ total_loss = (
+ F.cross_entropy(logits_per_audio, labels)
+ + F.cross_entropy(logits_per_text, labels)
+ ) / 2
+ else:
+ audio_weight = (all_audio_features @ all_audio_features.T).detach()
+ audio_weight = (
+ torch.exp(
+ torch.sum(audio_weight, axis=1)
+ / (self.weight_loss_kappa * len(all_audio_features))
+ )
+ ).detach()
+ text_weight = (all_text_features @ all_text_features.T).detach()
+ text_weight = (
+ torch.exp(
+ torch.sum(text_weight, axis=1)
+ / (self.weight_loss_kappa * len(all_text_features))
+ )
+ ).detach()
+ total_loss = (
+ F.cross_entropy(logits_per_audio, labels, weight=text_weight)
+ + F.cross_entropy(logits_per_text, labels, weight=audio_weight)
+ ) / 2
+ return total_loss
+
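+# Usage sketch for ClipLoss in a single-process setting (assumes unit-normalised
+# features and a learned logit scale; illustrative only):
+#     loss_fn = ClipLoss()
+#     loss = loss_fn(audio_features, text_features, logit_scale_a=logit_scale.exp())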
+
+def lp_gather_features(pred, target, world_size=1, use_horovod=False):
+ if use_horovod:
+ assert hvd is not None, "Please install horovod"
+ with torch.no_grad():
+ all_preds = hvd.allgather(pred)
+            all_targets = hvd.allgather(target)
+ else:
+ gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
+ gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]
+
+ dist.all_gather(gathered_preds, pred)
+ dist.all_gather(gathered_targets, target)
+ all_preds = torch.cat(gathered_preds, dim=0)
+ all_targets = torch.cat(gathered_targets, dim=0)
+
+ return all_preds, all_targets
+
+
+def get_map(pred, target):
+ pred = torch.sigmoid(pred).numpy()
+ target = target.numpy()
+ return np.mean(average_precision_score(target, pred, average=None))
+
+
+def get_acc(pred, target):
+ pred = torch.argmax(pred, 1).numpy()
+ target = torch.argmax(target, 1).numpy()
+ return accuracy_score(target, pred)
+
+
+def get_mauc(pred, target):
+ pred = torch.sigmoid(pred).numpy()
+ target = target.numpy()
+ return np.mean(roc_auc_score(target, pred, average=None))
+
+
+class LPMetrics(object):
+ def __init__(self, metric_names=["map", "acc", "mauc"]):
+ self.metrics = []
+ for name in metric_names:
+ self.metrics.append(self.get_metric(name))
+ self.metric_names = metric_names
+
+ def get_metric(self, name):
+ if name == "map":
+ return get_map
+ elif name == "acc":
+ return get_acc
+ elif name == "mauc":
+ return get_mauc
+ else:
+            raise ValueError("the metric should be one of [map, acc, mauc]")
+
+ def evaluate_mertics(self, pred, target):
+ metric_dict = {}
+ for i in range(len(self.metric_names)):
+ metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
+ return metric_dict
+
+
+def calc_celoss(pred, target):
+ target = torch.argmax(target, 1).long()
+ return nn.CrossEntropyLoss()(pred, target)
+
+
+class LPLoss(nn.Module):
+ def __init__(self, loss_name):
+ super().__init__()
+ if loss_name == "bce":
+ self.loss_func = nn.BCEWithLogitsLoss()
+ elif loss_name == "ce":
+ self.loss_func = calc_celoss
+ elif loss_name == "mse":
+ self.loss_func = nn.MSELoss()
+ else:
+            raise ValueError("the loss function should be one of [bce, ce, mse]")
+
+ def forward(self, pred, target):
+ loss = self.loss_func(pred, target)
+ return loss
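+
+
+# Usage sketch for the linear-probe helpers (hypothetical CPU tensors of logits and
+# multi-hot targets):
+#     metrics = LPMetrics(metric_names=["map", "mauc"])
+#     scores = metrics.evaluate_mertics(pred_logits, targets)
+#     loss = LPLoss("bce")(pred_logits, targets.float())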
diff --git a/audioldm/clap/open_clip/model.py b/audioldm/clap/open_clip/model.py
new file mode 100755
index 0000000000000000000000000000000000000000..b439244f8c293a0b4263b7ac1fd553e9d0adf184
--- /dev/null
+++ b/audioldm/clap/open_clip/model.py
@@ -0,0 +1,936 @@
+""" CLAP Model
+
+Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+Adapted to the Audio Task.
+"""
+
+from collections import OrderedDict
+from dataclasses import dataclass
+from email.mime import audio
+from typing import Tuple, Union, Callable, Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .timm_model import TimmModel
+import logging
+from .utils import freeze_batch_norm_2d
+
+from .pann_model import create_pann_model
+from .htsat import create_htsat_model
+from transformers import BertModel, RobertaModel, BartModel
+from transformers.tokenization_utils_base import BatchEncoding
+
+
+class MLPLayers(nn.Module):
+ def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
+ super(MLPLayers, self).__init__()
+ self.nonlin = nonlin
+ self.dropout = dropout
+
+ sequence = []
+ for u0, u1 in zip(units[:-1], units[1:]):
+ sequence.append(nn.Linear(u0, u1))
+ sequence.append(self.nonlin)
+ sequence.append(nn.Dropout(self.dropout))
+ sequence = sequence[:-2]
+
+ self.sequential = nn.Sequential(*sequence)
+
+ def forward(self, X):
+ X = self.sequential(X)
+ return X
+
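+# Construction sketch: MLPLayers(units=[512, 1024, 527]) builds
+#     Linear(512, 1024) -> ReLU -> Dropout(0.1) -> Linear(1024, 527),
+# since the trailing nonlinearity and dropout are stripped by `sequence = sequence[:-2]`.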
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(
+ OrderedDict(
+ [
+ ("-1", nn.AvgPool2d(stride)),
+ (
+ "0",
+ nn.Conv2d(
+ inplanes,
+ planes * self.expansion,
+ 1,
+ stride=1,
+ bias=False,
+ ),
+ ),
+ ("1", nn.BatchNorm2d(planes * self.expansion)),
+ ]
+ )
+ )
+
+ def forward(self, x: torch.Tensor):
+ identity = x
+
+ out = self.relu(self.bn1(self.conv1(x)))
+ out = self.relu(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+ return out
+
+
+class AttentionPool2d(nn.Module):
+ def __init__(
+ self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(
+ torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
+ )
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
+ 2, 0, 1
+ ) # NCHW -> (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x,
+ key=x,
+ value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat(
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
+ ),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False,
+ )
+
+ return x[0]
+
+
+class ModifiedResNet(nn.Module):
+ """
+ A ResNet class that is similar to torchvision's but contains the following changes:
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.image_size = image_size
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(
+ 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
+ )
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.conv2 = nn.Conv2d(
+ width // 2, width // 2, kernel_size=3, padding=1, bias=False
+ )
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.avgpool = nn.AvgPool2d(2)
+ self.relu = nn.ReLU(inplace=True)
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
+
+ self.init_parameters()
+
+ def _make_layer(self, planes, blocks, stride=1):
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def init_parameters(self):
+ if self.attnpool is not None:
+ std = self.attnpool.c_proj.in_features**-0.5
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
+
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
+ for name, param in resnet_block.named_parameters():
+ if name.endswith("bn3.weight"):
+ nn.init.zeros_(param)
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ assert (
+ unlocked_groups == 0
+ ), "partial locking not currently supported for this model"
+ for param in self.parameters():
+ param.requires_grad = False
+ if freeze_bn_stats:
+ freeze_batch_norm_2d(self)
+
+ def stem(self, x):
+ for conv, bn in [
+ (self.conv1, self.bn1),
+ (self.conv2, self.bn2),
+ (self.conv3, self.bn3),
+ ]:
+ x = self.relu(bn(conv(x)))
+ x = self.avgpool(x)
+ return x
+
+ def forward(self, x):
+ x = self.stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.attnpool(x)
+
+ return x
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ return x.to(orig_type)
+
+
+class QuickGELU(nn.Module):
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(
+ OrderedDict(
+ [
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
+ ("gelu", act_layer()),
+ ("c_proj", nn.Linear(d_model * 4, d_model)),
+ ]
+ )
+ )
+ self.ln_2 = LayerNorm(d_model)
+
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
+
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
+ ):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.ModuleList(
+ [
+ ResidualAttentionBlock(width, heads, act_layer=act_layer)
+ for _ in range(layers)
+ ]
+ )
+
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+ for r in self.resblocks:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+
+class VisualTransformer(nn.Module):
+ def __init__(
+ self,
+ image_size: int,
+ patch_size: int,
+ width: int,
+ layers: int,
+ heads: int,
+ output_dim: int,
+ act_layer: Callable = nn.GELU,
+ ):
+ super().__init__()
+ self.image_size = image_size
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv2d(
+ in_channels=3,
+ out_channels=width,
+ kernel_size=patch_size,
+ stride=patch_size,
+ bias=False,
+ )
+
+ scale = width**-0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(
+ scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
+ )
+ self.ln_pre = LayerNorm(width)
+
+ self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)
+
+ self.ln_post = LayerNorm(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ assert (
+ unlocked_groups == 0
+ ), "partial locking not currently supported for this model"
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, x: torch.Tensor):
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ x = torch.cat(
+ [
+ self.class_embedding.to(x.dtype)
+ + torch.zeros(
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
+ ),
+ x,
+ ],
+ dim=1,
+ ) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.positional_embedding.to(x.dtype)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_branch(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ x = self.ln_post(x[:, 0, :])
+
+ if self.proj is not None:
+ x = x @ self.proj
+
+ return x
+
+
+@dataclass
+class CLAPVisionCfg:
+ layers: Union[Tuple[int, int, int, int], int] = 12
+ width: int = 768
+ patch_size: int = 16
+ image_size: Union[Tuple[int, int], int] = 224
+    timm_model_name: Optional[str] = (
+        None  # a valid model name overrides layers, width, patch_size
+    )
+ timm_model_pretrained: bool = (
+ False # use (imagenet) pretrained weights for named model
+ )
+ timm_pool: str = (
+ "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
+ )
+ timm_proj: str = (
+ "linear" # linear projection for timm model output ('linear', 'mlp', '')
+ )
+
+
+# Audio Config Class
+@dataclass
+class CLAPAudioCfp:
+ model_type: str = "PANN"
+ model_name: str = "Cnn14"
+ sample_rate: int = 48000
+    # STFT / mel-spectrogram parameters
+ audio_length: int = 1024
+ window_size: int = 1024
+ hop_size: int = 1024
+ fmin: int = 50
+ fmax: int = 14000
+ class_num: int = 527
+ mel_bins: int = 64
+ clip_samples: int = 480000
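+# Construction sketch (comment only): this dataclass is typically filled from the
+# "audio_cfg" block of the JSON model configs shipped alongside this package, e.g.
+#     import json
+#     cfg = json.load(open("model_configs/HTSAT-base.json"))
+#     audio_cfg = CLAPAudioCfp(**cfg["audio_cfg"])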
+
+
+@dataclass
+class CLAPTextCfg:
+ context_length: int
+ vocab_size: int
+ width: int
+ heads: int
+ layers: int
+ model_type: str
+
+
+class CLAP(nn.Module):
+ def __init__(
+ self,
+ embed_dim: int,
+ audio_cfg: CLAPAudioCfp,
+ text_cfg: CLAPTextCfg,
+ quick_gelu: bool = False,
+ enable_fusion: bool = False,
+ fusion_type: str = "None",
+ joint_embed_shape: int = 512,
+ mlp_act: str = "relu",
+ ):
+ super().__init__()
+ if isinstance(audio_cfg, dict):
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
+ if isinstance(text_cfg, dict):
+ text_cfg = CLAPTextCfg(**text_cfg)
+
+ self.audio_cfg = audio_cfg
+ self.text_cfg = text_cfg
+ self.enable_fusion = enable_fusion
+ self.fusion_type = fusion_type
+ self.joint_embed_shape = joint_embed_shape
+ self.mlp_act = mlp_act
+
+ self.context_length = text_cfg.context_length
+
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
+ # memory efficient in recent PyTorch releases (>= 1.10).
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
+ act_layer = QuickGELU if quick_gelu else nn.GELU
+
+ if mlp_act == "relu":
+ mlp_act_layer = nn.ReLU()
+ elif mlp_act == "gelu":
+ mlp_act_layer = nn.GELU()
+ else:
+ raise NotImplementedError
+
+ # audio branch
+ # audio branch parameters
+ if audio_cfg.model_type == "PANN":
+ self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
+ elif audio_cfg.model_type == "HTSAT":
+ self.audio_branch = create_htsat_model(
+ audio_cfg, enable_fusion, fusion_type
+ )
+ else:
+ logging.error(f"Model config for {audio_cfg.model_type} not found")
+ raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")
+
+ # text branch
+ # text branch parameters
+ if text_cfg.model_type == "transformer":
+ self.text_branch = Transformer(
+ width=text_cfg.width,
+ layers=text_cfg.layers,
+ heads=text_cfg.heads,
+ act_layer=act_layer,
+ )
+ self.vocab_size = text_cfg.vocab_size
+ self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
+ self.positional_embedding = nn.Parameter(
+ torch.empty(self.context_length, text_cfg.width)
+ )
+ self.ln_final = LayerNorm(text_cfg.width)
+ self.text_transform = MLPLayers(
+ units=[
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ ],
+ dropout=0.1,
+ )
+ self.text_projection = nn.Sequential(
+ nn.Linear(text_cfg.width, self.joint_embed_shape),
+ mlp_act_layer,
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
+ )
+ elif text_cfg.model_type == "bert":
+ self.text_branch = BertModel.from_pretrained("bert-base-uncased")
+ self.text_transform = MLPLayers(
+ units=[
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ ],
+ dropout=0.1,
+ )
+ self.text_projection = nn.Sequential(
+ nn.Linear(768, self.joint_embed_shape),
+ mlp_act_layer,
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
+ )
+ elif text_cfg.model_type == "roberta":
+ self.text_branch = RobertaModel.from_pretrained("roberta-base")
+ self.text_transform = MLPLayers(
+ units=[
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ ],
+ dropout=0.1,
+ )
+ self.text_projection = nn.Sequential(
+ nn.Linear(768, self.joint_embed_shape),
+ mlp_act_layer,
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
+ )
+ elif text_cfg.model_type == "bart":
+ self.text_branch = BartModel.from_pretrained("facebook/bart-base")
+ self.text_transform = MLPLayers(
+ units=[
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ ],
+ dropout=0.1,
+ )
+ self.text_projection = nn.Sequential(
+ nn.Linear(768, self.joint_embed_shape),
+ mlp_act_layer,
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
+ )
+ else:
+ logging.error(f"Model config for {text_cfg.model_type} not found")
+ raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
+ self.text_branch_type = text_cfg.model_type
+        # end of text branch parameters
+
+ # audio branch parameters
+ self.audio_transform = MLPLayers(
+ units=[
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ self.joint_embed_shape,
+ ],
+ dropout=0.1,
+ )
+
+        # project the audio embedding into the joint audio-text embedding space
+ self.audio_projection = nn.Sequential(
+ nn.Linear(embed_dim, self.joint_embed_shape),
+ mlp_act_layer,
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
+ )
+
+ self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+ self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+ self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
+
+ self.init_text_branch_parameters()
+
+ def init_text_branch_parameters(self):
+ if self.text_branch_type == "transformer":
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
+ nn.init.normal_(self.positional_embedding, std=0.01)
+ proj_std = (self.text_branch.width**-0.5) * (
+ (2 * self.text_branch.layers) ** -0.5
+ )
+ attn_std = self.text_branch.width**-0.5
+ fc_std = (2 * self.text_branch.width) ** -0.5
+ for block in self.text_branch.resblocks:
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+ if self.text_branch_type == "bert" or self.text_branch_type == "roberta":
+ width = self.text_branch.embeddings.word_embeddings.weight.shape[-1]
+ elif self.text_branch_type == "bart":
+ width = self.text_branch.shared.weight.shape[-1]
+ else:
+ width = self.text_branch.width
+ nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
+ nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
+
+ # deprecated
+ # if hasattr(self.visual, 'init_parameters'):
+ # self.visual.init_parameters()
+
+ # if self.text_projection is not None:
+ # nn.init.normal_(self.text_projection, std=width**-0.5)
+
+ def build_attention_mask(self):
+        # lazily create the causal attention mask used by the text transformer branch
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
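+    # Illustrative example (comment only): for context_length == 3 the mask above is
+    #     [[0., -inf, -inf],
+    #      [0.,   0., -inf],
+    #      [0.,   0.,   0.]]
+    # i.e. an additive causal mask where position i may attend only to positions <= i.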
+
+ def encode_audio(self, audio, device):
+ return self.audio_branch(
+ audio, mixup_lambda=None, device=device
+        )  # mixup_lambda is unused here; mixup support at inference is still to be added
+
+ # def list_of_dict_of_tensor2dict_of_tensor(self, x, device):
+ # tmp = {}
+ # for k in x[0].keys():
+ # tmp[k] = []
+ # for i in range(len(x)):
+ # tmp[k].append(x[i][k][:77])
+ # for k in x[0].keys():
+ # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True)
+ # return tmp
+
+ def encode_text(self, text, device):
+ if self.text_branch_type == "transformer":
+ text = text.to(device=device, non_blocking=True)
+ x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_branch(x, attn_mask=self.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x)
+
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
+ elif self.text_branch_type == "bert":
+ # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device)
+ # text = BatchEncoding(text)
+ x = self.text_branch(
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
+ attention_mask=text["attention_mask"].to(
+ device=device, non_blocking=True
+ ),
+ token_type_ids=text["token_type_ids"].to(
+ device=device, non_blocking=True
+ ),
+ )["pooler_output"]
+ x = self.text_projection(x)
+ elif self.text_branch_type == "roberta":
+ x = self.text_branch(
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
+ attention_mask=text["attention_mask"].to(
+ device=device, non_blocking=True
+ ),
+ )["pooler_output"]
+ x = self.text_projection(x)
+ elif self.text_branch_type == "bart":
+ x = torch.mean(
+ self.text_branch(
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
+ attention_mask=text["attention_mask"].to(
+ device=device, non_blocking=True
+ ),
+ )["encoder_last_hidden_state"],
+ axis=1,
+ )
+ x = self.text_projection(x)
+ else:
+ logging.error(f"Model type {self.text_branch_type} not found")
+ raise RuntimeError(f"Model type {self.text_branch_type} not found.")
+ return x
+
+ def forward(self, audio, text, device=None):
+        """Forward audio and text through CLAP.
+
+        Parameters
+        ----------
+        audio: torch.Tensor or dict
+            the time-domain audio input, or the audio feature dict (waveform,
+            mel_fusion, longer) produced by the data pipeline
+        text: torch.Tensor or dict
+            the tokenized text input: a token tensor for the plain transformer
+            branch, or a tokenizer output dict for the BERT / RoBERTa / BART branches
+        """
+ if device is None:
+ if audio is not None:
+ device = audio.device
+ elif text is not None:
+ device = text.device
+ if audio is None and text is None:
+ # a hack to get the logit scale
+ return self.logit_scale_a.exp(), self.logit_scale_t.exp()
+ elif audio is None:
+ return self.encode_text(text, device=device)
+ elif text is None:
+ return self.audio_projection(
+ self.encode_audio(audio, device=device)["embedding"]
+ )
+ audio_features = self.audio_projection(
+ self.encode_audio(audio, device=device)["embedding"]
+ )
+ audio_features = F.normalize(audio_features, dim=-1)
+
+ text_features = self.encode_text(text, device=device)
+ text_features = F.normalize(text_features, dim=-1)
+
+ audio_features_mlp = self.audio_transform(audio_features)
+ text_features_mlp = self.text_transform(text_features)
+ # Four outputs: audio features (basic & MLP), text features (basic & MLP)
+ return (
+ audio_features,
+ text_features,
+ audio_features_mlp,
+ text_features_mlp,
+ self.logit_scale_a.exp(),
+ self.logit_scale_t.exp(),
+ )
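+    # Usage sketch (comment only; the names below are placeholders, not part of
+    # this module): the six outputs feed a CLIP-style contrastive objective, roughly
+    #     logits_per_audio = scale_a * audio_feat @ text_feat.t()
+    #     logits_per_text  = logits_per_audio.t()
+    # followed by a symmetric cross-entropy over the batch.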
+
+ def get_logit_scale(self):
+ return self.logit_scale_a.exp(), self.logit_scale_t.exp()
+
+ def get_text_embedding(self, data):
+        """Get the text embedding from the model.
+
+        Parameters
+        ----------
+        data: dict
+            tokenizer output (e.g. input_ids / attention_mask tensors); each
+            tensor is moved to the model device and given a batch dimension if needed
+
+        Returns
+        -------
+        text_embeds: torch.Tensor
+            the L2-normalized text embeddings, shape (N, D)
+        """
+ device = next(self.parameters()).device
+ for k in data:
+ data[k] = data[k].to(device)
+ if len(data[k].size()) < 2:
+ data[k] = data[k].unsqueeze(0)
+ text_embeds = self.encode_text(data, device=device)
+ text_embeds = F.normalize(text_embeds, dim=-1)
+
+ return text_embeds
+
+ def get_audio_embedding(self, data):
+        """Get the audio embedding from the model.
+
+        Parameters
+        ----------
+        data: list of dict
+            the audio input dicts produced by the 'get_audio_feature' method
+
+        Returns
+        -------
+        audio_embeds: torch.Tensor
+            the L2-normalized audio embeddings, shape (N, D)
+        """
+ device = next(self.parameters()).device
+ input_dict = {}
+ keys = data[0].keys()
+ for k in keys:
+ input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(
+ device
+ )
+
+ audio_embeds = self.audio_projection(
+ self.encode_audio(input_dict, device=device)["embedding"]
+ )
+ audio_embeds = F.normalize(audio_embeds, dim=-1)
+
+ return audio_embeds
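+    # Hypothetical retrieval sketch (comment only; `text_batch` and `audio_dicts`
+    # stand for the tokenized text dict and the list of audio feature dicts
+    # produced by the surrounding pipeline):
+    #     with torch.no_grad():
+    #         t = model.get_text_embedding(text_batch)
+    #         a = model.get_audio_embedding(audio_dicts)
+    #         sims = a @ t.t()   # cosine similarities, since both are L2-normalized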
+
+ def audio_infer(self, audio, hopsize=None, device=None):
+        """Run inference on a single audio clip and produce its embedding.
+
+        Parameters
+        ----------
+        audio: torch.Tensor (audio_length,)
+            the time-domain audio input; note that it must be a single,
+            unbatched waveform
+        hopsize: int
+            hop size (in samples) of the sliding window used when the audio is
+            longer than one clip
+
+        Returns
+        -------
+        output_dict: dict
+            {"embedding": (n_windows, embedding_dim)} for HTSAT (one entry per
+            sliding window), or {"embedding": (embedding_dim,)} for PANN
+        """
+
+        assert not self.training, "audio_infer must be run in eval mode"
+        output_dict = {}
+        # the audio branch returns a dict; "embedding" is assumed to be the key of interest here
+        key = "embedding"
+        # PANN
+        if self.audio_cfg.model_type == "PANN":
+            audio_input = audio.unsqueeze(dim=0)
+            output_dict[key] = self.encode_audio(audio_input, device=device)[
+                key
+            ].squeeze(dim=0)
+ elif self.audio_cfg.model_type == "HTSAT":
+ # repeat
+ audio_len = len(audio)
+ k = self.audio_cfg.clip_samples // audio_len
+ if k > 1:
+ audio = audio.repeat(k)
+ audio_len = len(audio)
+
+            if hopsize is None:
+                # assume non-overlapping windows when no hopsize is given
+                hopsize = self.audio_cfg.clip_samples
+            hopsize = min(hopsize, audio_len)
+
+ if audio_len > self.audio_cfg.clip_samples:
+ audio_input = [
+ audio[pos : pos + self.audio_cfg.clip_samples].clone()
+ for pos in range(
+ 0, audio_len - self.audio_cfg.clip_samples, hopsize
+ )
+ ]
+ audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
+ audio_input = torch.stack(audio_input)
+ output_dict[key] = self.encode_audio(audio_input, device=device)[key]
+ else:
+ audio_input = audio.unsqueeze(dim=0)
+ output_dict[key] = self.encode_audio(audio_input, device=device)[
+ key
+ ].squeeze(dim=0)
+
+ return output_dict
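+    # Window-count example (comment only), assuming the HTSAT defaults above
+    # (clip_samples = 480000, i.e. 10 s at 48 kHz): a 30 s input with a hopsize of
+    # one full clip yields range(0, 960000, 480000) -> 2 windows plus the final
+    # clip anchored at the end, so output_dict["embedding"] stacks 3 window embeddings.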
+
+
+def convert_weights_to_fp16(model: nn.Module):
+ """Convert applicable model parameters to fp16"""
+
+ def _convert_weights_to_fp16(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.half()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.half()
+
+ if isinstance(l, nn.MultiheadAttention):
+ for attr in [
+ *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
+ "in_proj_bias",
+ "bias_k",
+ "bias_v",
+ ]:
+ tensor = getattr(l, attr)
+ if tensor is not None:
+ tensor.data = tensor.data.half()
+
+ for name in ["text_projection", "proj"]:
+ if hasattr(l, name):
+ attr = getattr(l, name)
+ if attr is not None:
+ attr.data = attr.data.half()
+
+ model.apply(_convert_weights_to_fp16)
+
+
+# Ignore the state dict of the vision part
+def build_model_from_openai_state_dict(
+ state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None"
+):
+
+ embed_dim = model_cfg["embed_dim"]
+ audio_cfg = model_cfg["audio_cfg"]
+ text_cfg = model_cfg["text_cfg"]
+ context_length = state_dict["positional_embedding"].shape[0]
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(
+ set(
+ k.split(".")[2]
+ for k in state_dict
+            if k.startswith("transformer.resblocks")
+ )
+ )
+
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
+ text_cfg = CLAPTextCfg(**text_cfg)
+
+ model = CLAP(
+ embed_dim,
+ audio_cfg=audio_cfg,
+ text_cfg=text_cfg,
+ quick_gelu=True, # OpenAI models were trained with QuickGELU
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type,
+ )
+ state_dict["logit_scale_a"] = state_dict["logit_scale"]
+ state_dict["logit_scale_t"] = state_dict["logit_scale"]
+    pop_keys = list(state_dict.keys())
+ # pop the visual branch saved weights
+ for key in pop_keys:
+ if key.startswith("visual."):
+ state_dict.pop(key, None)
+
+ for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]:
+ state_dict.pop(key, None)
+
+ # not use fp16
+ # convert_weights_to_fp16(model)
+ model.load_state_dict(state_dict, strict=False)
+ return model.eval()
+
+
+def trace_model(model, batch_size=256, device=torch.device("cpu")):
+ model.eval()
+ audio_length = model.audio_cfg.audio_length
+ example_audio = torch.ones((batch_size, audio_length), device=device)
+ example_text = torch.zeros(
+ (batch_size, model.context_length), dtype=torch.int, device=device
+ )
+ model = torch.jit.trace_module(
+ model,
+ inputs=dict(
+ forward=(example_audio, example_text),
+ encode_text=(example_text,),
+ encode_image=(example_audio,),
+ ),
+ )
+    model.audio_cfg.audio_length = audio_length  # restore audio_length on the traced module's config
+ return model
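+# Note on trace_model (comment only, not from the upstream code): the CLAP class
+# above defines no `encode_image` method, and `encode_audio` expects the audio
+# feature dict for the PANN / HTSAT branches, so this helper is unlikely to trace
+# successfully without first adapting the example inputs to that format.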
diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-base.json b/audioldm/clap/open_clip/model_configs/HTSAT-base.json
new file mode 100755
index 0000000000000000000000000000000000000000..6cef625a89daf4431f1c9f72e10bc9640eef2ba8
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/HTSAT-base.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 1024,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "HTSAT",
+ "model_name": "base"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-large.json b/audioldm/clap/open_clip/model_configs/HTSAT-large.json
new file mode 100755
index 0000000000000000000000000000000000000000..699cdb1b16855582606551e4196b24aba2ffd871
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/HTSAT-large.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 2048,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "HTSAT",
+ "model_name": "large"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json b/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json
new file mode 100755
index 0000000000000000000000000000000000000000..73e42990fe8361a0df502e7f93d29f19f58c9ecb
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 768,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1536,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "HTSAT",
+ "model_name": "tiny"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json b/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json
new file mode 100755
index 0000000000000000000000000000000000000000..a6e7821163d9afa81c27345a1e472475b92af169
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 768,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "HTSAT",
+ "model_name": "tiny"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/PANN-10.json b/audioldm/clap/open_clip/model_configs/PANN-10.json
new file mode 100755
index 0000000000000000000000000000000000000000..954ddf62921aed7dde9c37ffffec98a2e96a4ee7
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/PANN-10.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 1024,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn10"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json
new file mode 100755
index 0000000000000000000000000000000000000000..b7989bc0cd95d0d39049b7524eba508b3e386439
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 2048,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 18000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn14"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json
new file mode 100755
index 0000000000000000000000000000000000000000..56bdb56bedc304ffa52d8bf5988cea2c1d82d14e
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 2048,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 960000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 360,
+ "fmin": 50,
+ "fmax": 8000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn14"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json b/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json
new file mode 100755
index 0000000000000000000000000000000000000000..5756e3bebc97cc985f512cb081930fee4e49bec1
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 2048,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn14"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 4
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json b/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json
new file mode 100755
index 0000000000000000000000000000000000000000..5a9e7e208b661619d5e26625e849da1adda8a475
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 2048,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1536,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn14"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/PANN-14.json b/audioldm/clap/open_clip/model_configs/PANN-14.json
new file mode 100755
index 0000000000000000000000000000000000000000..39a5134cde1d8c50f4758377c952ef22f07bab41
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/PANN-14.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 2048,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn14"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/PANN-6.json b/audioldm/clap/open_clip/model_configs/PANN-6.json
new file mode 100755
index 0000000000000000000000000000000000000000..21ebc344326de260c386ba77e0ad63cf9b04febf
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/PANN-6.json
@@ -0,0 +1,23 @@
+{
+ "embed_dim": 512,
+ "audio_cfg": {
+ "audio_length": 1024,
+ "clip_samples": 480000,
+ "mel_bins": 64,
+ "sample_rate": 48000,
+ "window_size": 1024,
+ "hop_size": 480,
+ "fmin": 50,
+ "fmax": 14000,
+ "class_num": 527,
+ "model_type": "PANN",
+ "model_name": "Cnn6"
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json b/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json
new file mode 100755
index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json
@@ -0,0 +1,22 @@
+{
+ "embed_dim": 512,
+ "quick_gelu": true,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": [
+ 3,
+ 4,
+ 23,
+ 3
+ ],
+ "width": 64,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/RN101.json b/audioldm/clap/open_clip/model_configs/RN101.json
new file mode 100755
index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/RN101.json
@@ -0,0 +1,21 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": [
+ 3,
+ 4,
+ 23,
+ 3
+ ],
+ "width": 64,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json b/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json
new file mode 100755
index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json
@@ -0,0 +1,22 @@
+{
+ "embed_dim": 1024,
+ "quick_gelu": true,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": [
+ 3,
+ 4,
+ 6,
+ 3
+ ],
+ "width": 64,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
diff --git a/audioldm/clap/open_clip/model_configs/RN50.json b/audioldm/clap/open_clip/model_configs/RN50.json
new file mode 100755
index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/RN50.json
@@ -0,0 +1,21 @@
+{
+ "embed_dim": 1024,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": [
+ 3,
+ 4,
+ 6,
+ 3
+ ],
+ "width": 64,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/RN50x16.json b/audioldm/clap/open_clip/model_configs/RN50x16.json
new file mode 100755
index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/RN50x16.json
@@ -0,0 +1,21 @@
+{
+ "embed_dim": 768,
+ "vision_cfg": {
+ "image_size": 384,
+ "layers": [
+ 6,
+ 8,
+ 18,
+ 8
+ ],
+ "width": 96,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 768,
+ "heads": 12,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/RN50x4.json b/audioldm/clap/open_clip/model_configs/RN50x4.json
new file mode 100755
index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/RN50x4.json
@@ -0,0 +1,21 @@
+{
+ "embed_dim": 640,
+ "vision_cfg": {
+ "image_size": 288,
+ "layers": [
+ 4,
+ 6,
+ 10,
+ 6
+ ],
+ "width": 80,
+ "patch_size": null
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 640,
+ "heads": 10,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/ViT-B-16.json b/audioldm/clap/open_clip/model_configs/ViT-B-16.json
new file mode 100755
index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/ViT-B-16.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 768,
+ "patch_size": 16
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json b/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json
new file mode 100755
index 0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json
@@ -0,0 +1,17 @@
+{
+ "embed_dim": 512,
+ "quick_gelu": true,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 768,
+ "patch_size": 32
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/ViT-B-32.json b/audioldm/clap/open_clip/model_configs/ViT-B-32.json
new file mode 100755
index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/ViT-B-32.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 512,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 12,
+ "width": 768,
+ "patch_size": 32
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 512,
+ "heads": 8,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/model_configs/ViT-L-14.json b/audioldm/clap/open_clip/model_configs/ViT-L-14.json
new file mode 100755
index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241
--- /dev/null
+++ b/audioldm/clap/open_clip/model_configs/ViT-L-14.json
@@ -0,0 +1,16 @@
+{
+ "embed_dim": 768,
+ "vision_cfg": {
+ "image_size": 224,
+ "layers": 24,
+ "width": 1024,
+ "patch_size": 14
+ },
+ "text_cfg": {
+ "context_length": 77,
+ "vocab_size": 49408,
+ "width": 768,
+ "heads": 12,
+ "layers": 12
+ }
+}
\ No newline at end of file
diff --git a/audioldm/clap/open_clip/openai.py b/audioldm/clap/open_clip/openai.py
new file mode 100755
index 0000000000000000000000000000000000000000..fcb624f54a8b9d2c4b11e3adb50c53c3261716d4
--- /dev/null
+++ b/audioldm/clap/open_clip/openai.py
@@ -0,0 +1,159 @@
+""" OpenAI pretrained model functions
+
+Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+
+import os
+import warnings
+from typing import Union, List
+
+import torch
+
+from .model import build_model_from_openai_state_dict
+from .pretrained import (
+ get_pretrained_url,
+ list_pretrained_tag_models,
+ download_pretrained,
+)
+
+__all__ = ["list_openai_models", "load_openai_model"]
+
+CACHE_DIR = os.getenv("AUDIOLDM_CACHE_DIR", "~/.cache")
+
+
+def list_openai_models() -> List[str]:
+ """Returns the names of available CLIP models"""
+ return list_pretrained_tag_models("openai")
+
+
+def load_openai_model(
+ name: str,
+ model_cfg,
+ device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
+ jit=True,
+ cache_dir=os.path.expanduser(f"{CACHE_DIR}/clip"),
+ enable_fusion: bool = False,
+ fusion_type: str = "None",
+):
+    """Load an OpenAI CLIP checkpoint and reuse its pretrained text branch inside a CLAP model.
+
+    Parameters
+    ----------
+    name : str
+        A model name listed by `list_openai_models()`, or the path to a model
+        checkpoint containing the state_dict
+    model_cfg : dict
+        The CLAP model config (embed_dim, audio_cfg, text_cfg)
+    device : Union[str, torch.device]
+        The device to put the loaded model on
+    jit : bool
+        Whether to load the optimized JIT model (default) or the more hackable
+        non-JIT model
+
+    Returns
+    -------
+    model : torch.nn.Module
+        The CLAP model, with its text branch initialized from the CLIP checkpoint
+    """
+ if get_pretrained_url(name, "openai"):
+ model_path = download_pretrained(
+ get_pretrained_url(name, "openai"), root=cache_dir
+ )
+ elif os.path.isfile(name):
+ model_path = name
+ else:
+ raise RuntimeError(
+ f"Model {name} not found; available models = {list_openai_models()}"
+ )
+
+ try:
+ # loading JIT archive
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+ state_dict = None
+ except RuntimeError:
+ # loading saved state dict
+ if jit:
+ warnings.warn(
+ f"File {model_path} is not a JIT archive. Loading as a state dict instead"
+ )
+ jit = False
+ state_dict = torch.load(model_path, map_location="cpu")
+
+ if not jit:
+ try:
+ model = build_model_from_openai_state_dict(
+ state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type
+ ).to(device)
+ except KeyError:
+ sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
+ model = build_model_from_openai_state_dict(
+ sd, model_cfg, enable_fusion, fusion_type
+ ).to(device)
+
+ if str(device) == "cpu":
+ model.float()
+ return model
+
+ # patch the device names
+ device_holder = torch.jit.trace(
+ lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
+ )
+ device_node = [
+ n
+ for n in device_holder.graph.findAllNodes("prim::Constant")
+ if "Device" in repr(n)
+ ][-1]
+
+ def patch_device(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("prim::Constant"):
+ if "value" in node.attributeNames() and str(node["value"]).startswith(
+ "cuda"
+ ):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_audio)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 on CPU
+ if str(device) == "cpu":
+ float_holder = torch.jit.trace(
+ lambda: torch.ones([]).float(), example_inputs=[]
+ )
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ try:
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ except RuntimeError:
+ graphs = []
+
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("aten::to"):
+ inputs = list(node.inputs())
+ for i in [
+ 1,
+ 2,
+ ]: # dtype can be the second or third argument to aten::to()
+ if inputs[i].node()["value"] == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_audio)
+ patch_float(model.encode_text)
+ model.float()
+
+ model.audio_branch.audio_length = model.audio_cfg.audio_length
+ return model
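+# Usage sketch (comment only; the model name and config dict are illustrative):
+#     model_cfg = {"embed_dim": 512, "audio_cfg": {...}, "text_cfg": {...}}
+#     model = load_openai_model("ViT-B-16", model_cfg, device="cpu", jit=False)
+# downloads the OpenAI checkpoint into AUDIOLDM_CACHE_DIR/clip (~/.cache/clip by
+# default) and returns an eval-mode CLAP model whose text branch comes from CLIP.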
diff --git a/audioldm/clap/open_clip/pann_model.py b/audioldm/clap/open_clip/pann_model.py
new file mode 100755
index 0000000000000000000000000000000000000000..0d9a8eb0bf897ad6ec04923361b01e5de433b2ef
--- /dev/null
+++ b/audioldm/clap/open_clip/pann_model.py
@@ -0,0 +1,704 @@
+# PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
+# Adapted from https://github.com/qiuqiangkong/audioset_tagging_cnn
+# Some layers are re-designed for CLAP
+import os
+
+os.environ["NUMBA_CACHE_DIR"] = "/tmp/"
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchlibrosa.stft import Spectrogram, LogmelFilterBank
+from torchlibrosa.augmentation import SpecAugmentation
+
+from .utils import do_mixup, interpolate, pad_framewise_output
+from .feature_fusion import iAFF, AFF, DAF
+
+
+def init_layer(layer):
+ """Initialize a Linear or Convolutional layer."""
+ nn.init.xavier_uniform_(layer.weight)
+
+ if hasattr(layer, "bias"):
+ if layer.bias is not None:
+ layer.bias.data.fill_(0.0)
+
+
+def init_bn(bn):
+ """Initialize a Batchnorm layer."""
+ bn.bias.data.fill_(0.0)
+ bn.weight.data.fill_(1.0)
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, in_channels, out_channels):
+
+ super(ConvBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ padding=(1, 1),
+ bias=False,
+ )
+
+ self.conv2 = nn.Conv2d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=(3, 3),
+ stride=(1, 1),
+ padding=(1, 1),
+ bias=False,
+ )
+
+ self.bn1 = nn.BatchNorm2d(out_channels)
+ self.bn2 = nn.BatchNorm2d(out_channels)
+
+ self.init_weight()
+
+ def init_weight(self):
+ init_layer(self.conv1)
+ init_layer(self.conv2)
+ init_bn(self.bn1)
+ init_bn(self.bn2)
+
+ def forward(self, input, pool_size=(2, 2), pool_type="avg"):
+
+ x = input
+ x = F.relu_(self.bn1(self.conv1(x)))
+ x = F.relu_(self.bn2(self.conv2(x)))
+ if pool_type == "max":
+ x = F.max_pool2d(x, kernel_size=pool_size)
+ elif pool_type == "avg":
+ x = F.avg_pool2d(x, kernel_size=pool_size)
+ elif pool_type == "avg+max":
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
+ x = x1 + x2
+ else:
+ raise Exception("Incorrect argument!")
+
+ return x
+
+
+class ConvBlock5x5(nn.Module):
+ def __init__(self, in_channels, out_channels):
+
+ super(ConvBlock5x5, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=(5, 5),
+ stride=(1, 1),
+ padding=(2, 2),
+ bias=False,
+ )
+
+ self.bn1 = nn.BatchNorm2d(out_channels)
+
+ self.init_weight()
+
+ def init_weight(self):
+ init_layer(self.conv1)
+ init_bn(self.bn1)
+
+ def forward(self, input, pool_size=(2, 2), pool_type="avg"):
+
+ x = input
+ x = F.relu_(self.bn1(self.conv1(x)))
+ if pool_type == "max":
+ x = F.max_pool2d(x, kernel_size=pool_size)
+ elif pool_type == "avg":
+ x = F.avg_pool2d(x, kernel_size=pool_size)
+ elif pool_type == "avg+max":
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
+ x = x1 + x2
+ else:
+ raise Exception("Incorrect argument!")
+
+ return x
+
+
+class AttBlock(nn.Module):
+ def __init__(self, n_in, n_out, activation="linear", temperature=1.0):
+ super(AttBlock, self).__init__()
+
+ self.activation = activation
+ self.temperature = temperature
+ self.att = nn.Conv1d(
+ in_channels=n_in,
+ out_channels=n_out,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True,
+ )
+ self.cla = nn.Conv1d(
+ in_channels=n_in,
+ out_channels=n_out,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True,
+ )
+
+ self.bn_att = nn.BatchNorm1d(n_out)
+ self.init_weights()
+
+ def init_weights(self):
+ init_layer(self.att)
+ init_layer(self.cla)
+ init_bn(self.bn_att)
+
+ def forward(self, x):
+ # x: (n_samples, n_in, n_time)
+ norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
+ cla = self.nonlinear_transform(self.cla(x))
+ x = torch.sum(norm_att * cla, dim=2)
+ return x, norm_att, cla
+
+ def nonlinear_transform(self, x):
+ if self.activation == "linear":
+ return x
+ elif self.activation == "sigmoid":
+ return torch.sigmoid(x)
+
+
+class Cnn14(nn.Module):
+ def __init__(
+ self,
+ sample_rate,
+ window_size,
+ hop_size,
+ mel_bins,
+ fmin,
+ fmax,
+ classes_num,
+ enable_fusion=False,
+ fusion_type="None",
+ ):
+
+ super(Cnn14, self).__init__()
+
+ window = "hann"
+ center = True
+ pad_mode = "reflect"
+ ref = 1.0
+ amin = 1e-10
+ top_db = None
+
+ self.enable_fusion = enable_fusion
+ self.fusion_type = fusion_type
+
+ # Spectrogram extractor
+ self.spectrogram_extractor = Spectrogram(
+ n_fft=window_size,
+ hop_length=hop_size,
+ win_length=window_size,
+ window=window,
+ center=center,
+ pad_mode=pad_mode,
+ freeze_parameters=True,
+ )
+
+ # Logmel feature extractor
+ self.logmel_extractor = LogmelFilterBank(
+ sr=sample_rate,
+ n_fft=window_size,
+ n_mels=mel_bins,
+ fmin=fmin,
+ fmax=fmax,
+ ref=ref,
+ amin=amin,
+ top_db=top_db,
+ freeze_parameters=True,
+ )
+
+ # Spec augmenter
+ self.spec_augmenter = SpecAugmentation(
+ time_drop_width=64,
+ time_stripes_num=2,
+ freq_drop_width=8,
+ freq_stripes_num=2,
+ )
+
+ self.bn0 = nn.BatchNorm2d(64)
+
+ if (self.enable_fusion) and (self.fusion_type == "channel_map"):
+ self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
+ else:
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
+ self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
+
+ self.fc1 = nn.Linear(2048, 2048, bias=True)
+ self.fc_audioset = nn.Linear(2048, classes_num, bias=True)
+
+ if (self.enable_fusion) and (
+ self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
+ ):
+ self.mel_conv1d = nn.Sequential(
+ nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
+ nn.BatchNorm1d(64), # No Relu
+ )
+ if self.fusion_type == "daf_1d":
+ self.fusion_model = DAF()
+ elif self.fusion_type == "aff_1d":
+ self.fusion_model = AFF(channels=64, type="1D")
+ elif self.fusion_type == "iaff_1d":
+ self.fusion_model = iAFF(channels=64, type="1D")
+
+ if (self.enable_fusion) and (
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
+ ):
+ self.mel_conv2d = nn.Sequential(
+ nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)),
+ nn.BatchNorm2d(64),
+ nn.ReLU(inplace=True),
+ )
+
+ if self.fusion_type == "daf_2d":
+ self.fusion_model = DAF()
+ elif self.fusion_type == "aff_2d":
+ self.fusion_model = AFF(channels=64, type="2D")
+ elif self.fusion_type == "iaff_2d":
+ self.fusion_model = iAFF(channels=64, type="2D")
+ self.init_weight()
+
+ def init_weight(self):
+ init_bn(self.bn0)
+ init_layer(self.fc1)
+ init_layer(self.fc_audioset)
+
+ def forward(self, input, mixup_lambda=None, device=None):
+        """input: a dict with a "waveform" tensor of shape (batch_size, data_length)
+        and, when fusion is enabled, "mel_fusion" / "longer" entries."""
+
+ if self.enable_fusion and input["longer"].sum() == 0:
+ # if no audio is longer than 10s, then randomly select one audio to be longer
+ input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True
+
+ if not self.enable_fusion:
+ x = self.spectrogram_extractor(
+ input["waveform"].to(device=device, non_blocking=True)
+ ) # (batch_size, 1, time_steps, freq_bins)
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
+
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+ else:
+ longer_list = input["longer"].to(device=device, non_blocking=True)
+ x = input["mel_fusion"].to(device=device, non_blocking=True)
+ longer_list_idx = torch.where(longer_list)[0]
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+ if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
+ new_x = x[:, 0:1, :, :].clone().contiguous()
+ # local processing
+ if len(longer_list_idx) > 0:
+ fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
+ FB, FC, FT, FF = fusion_x_local.size()
+ fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
+ fusion_x_local = torch.permute(
+ fusion_x_local, (0, 2, 1)
+ ).contiguous()
+ fusion_x_local = self.mel_conv1d(fusion_x_local)
+ fusion_x_local = fusion_x_local.view(
+ FB, FC, FF, fusion_x_local.size(-1)
+ )
+ fusion_x_local = (
+ torch.permute(fusion_x_local, (0, 2, 1, 3))
+ .contiguous()
+ .flatten(2)
+ )
+ if fusion_x_local.size(-1) < FT:
+ fusion_x_local = torch.cat(
+ [
+ fusion_x_local,
+ torch.zeros(
+ (FB, FF, FT - fusion_x_local.size(-1)),
+ device=device,
+ ),
+ ],
+ dim=-1,
+ )
+ else:
+ fusion_x_local = fusion_x_local[:, :, :FT]
+ # 1D fusion
+ new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
+ new_x[longer_list_idx] = self.fusion_model(
+ new_x[longer_list_idx], fusion_x_local
+ )
+ x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
+ else:
+ x = new_x
+ elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
+ x = x # no change
+
+ if self.training:
+ x = self.spec_augmenter(x)
+ # Mixup on spectrogram
+ if self.training and mixup_lambda is not None:
+ x = do_mixup(x, mixup_lambda)
+ if (self.enable_fusion) and (
+ self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
+ ):
+ global_x = x[:, 0:1, :, :]
+
+ # global processing
+ B, C, H, W = global_x.shape
+ global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg")
+ if len(longer_list_idx) > 0:
+ local_x = x[longer_list_idx, 1:, :, :].contiguous()
+ TH = global_x.size(-2)
+ # local processing
+ B, C, H, W = local_x.shape
+ local_x = local_x.view(B * C, 1, H, W)
+ local_x = self.mel_conv2d(local_x)
+ local_x = local_x.view(
+ B, C, local_x.size(1), local_x.size(2), local_x.size(3)
+ )
+ local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3)
+ TB, TC, _, TW = local_x.size()
+ if local_x.size(-2) < TH:
+ local_x = torch.cat(
+ [
+ local_x,
+ torch.zeros(
+ (TB, TC, TH - local_x.size(-2), TW),
+ device=global_x.device,
+ ),
+ ],
+ dim=-2,
+ )
+ else:
+ local_x = local_x[:, :, :TH, :]
+
+ global_x[longer_list_idx] = self.fusion_model(
+ global_x[longer_list_idx], local_x
+ )
+ x = global_x
+ else:
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
+
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = torch.mean(x, dim=3)
+
+ latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
+ latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
+ latent_x = latent_x1 + latent_x2
+ latent_x = latent_x.transpose(1, 2)
+ latent_x = F.relu_(self.fc1(latent_x))
+ latent_output = interpolate(latent_x, 32)
+
+ (x1, _) = torch.max(x, dim=2)
+ x2 = torch.mean(x, dim=2)
+ x = x1 + x2
+ x = F.dropout(x, p=0.5, training=self.training)
+ x = F.relu_(self.fc1(x))
+ embedding = F.dropout(x, p=0.5, training=self.training)
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
+
+ output_dict = {
+ "clipwise_output": clipwise_output,
+ "embedding": embedding,
+ "fine_grained_embedding": latent_output,
+ }
+ return output_dict
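+    # Output-shape note (comment only), assuming the default CLAP audio config
+    # (48 kHz, 10 s clips, 64 mel bins): "embedding" is (batch, 2048),
+    # "clipwise_output" is (batch, classes_num) after a sigmoid, and
+    # "fine_grained_embedding" holds the frame-level features upsampled along the
+    # time axis by interpolate(..., 32).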
+
+
+class Cnn6(nn.Module):
+ def __init__(
+ self,
+ sample_rate,
+ window_size,
+ hop_size,
+ mel_bins,
+ fmin,
+ fmax,
+ classes_num,
+ enable_fusion=False,
+ fusion_type="None",
+ ):
+
+ super(Cnn6, self).__init__()
+
+ window = "hann"
+ center = True
+ pad_mode = "reflect"
+ ref = 1.0
+ amin = 1e-10
+ top_db = None
+
+ self.enable_fusion = enable_fusion
+ self.fusion_type = fusion_type
+
+ # Spectrogram extractor
+ self.spectrogram_extractor = Spectrogram(
+ n_fft=window_size,
+ hop_length=hop_size,
+ win_length=window_size,
+ window=window,
+ center=center,
+ pad_mode=pad_mode,
+ freeze_parameters=True,
+ )
+
+ # Logmel feature extractor
+ self.logmel_extractor = LogmelFilterBank(
+ sr=sample_rate,
+ n_fft=window_size,
+ n_mels=mel_bins,
+ fmin=fmin,
+ fmax=fmax,
+ ref=ref,
+ amin=amin,
+ top_db=top_db,
+ freeze_parameters=True,
+ )
+
+ # Spec augmenter
+ self.spec_augmenter = SpecAugmentation(
+ time_drop_width=64,
+ time_stripes_num=2,
+ freq_drop_width=8,
+ freq_stripes_num=2,
+ )
+
+ self.bn0 = nn.BatchNorm2d(64)
+
+ self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
+ self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
+ self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
+ self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)
+
+ self.fc1 = nn.Linear(512, 512, bias=True)
+ self.fc_audioset = nn.Linear(512, classes_num, bias=True)
+
+ self.init_weight()
+
+ def init_weight(self):
+ init_bn(self.bn0)
+ init_layer(self.fc1)
+ init_layer(self.fc_audioset)
+
+ def forward(self, input, mixup_lambda=None, device=None):
+        """Input: (batch_size, data_length)"""
+
+ x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
+
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+
+ if self.training:
+ x = self.spec_augmenter(x)
+
+ # Mixup on spectrogram
+ if self.training and mixup_lambda is not None:
+ x = do_mixup(x, mixup_lambda)
+
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = torch.mean(x, dim=3)
+
+ latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
+ latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
+ latent_x = latent_x1 + latent_x2
+ latent_x = latent_x.transpose(1, 2)
+ latent_x = F.relu_(self.fc1(latent_x))
+ latent_output = interpolate(latent_x, 16)
+
+ (x1, _) = torch.max(x, dim=2)
+ x2 = torch.mean(x, dim=2)
+ x = x1 + x2
+ x = F.dropout(x, p=0.5, training=self.training)
+ x = F.relu_(self.fc1(x))
+ embedding = F.dropout(x, p=0.5, training=self.training)
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
+
+ output_dict = {
+ "clipwise_output": clipwise_output,
+ "embedding": embedding,
+ "fine_grained_embedding": latent_output,
+ }
+
+ return output_dict
+
+
+class Cnn10(nn.Module):
+ def __init__(
+ self,
+ sample_rate,
+ window_size,
+ hop_size,
+ mel_bins,
+ fmin,
+ fmax,
+ classes_num,
+ enable_fusion=False,
+ fusion_type="None",
+ ):
+
+ super(Cnn10, self).__init__()
+
+ window = "hann"
+ center = True
+ pad_mode = "reflect"
+ ref = 1.0
+ amin = 1e-10
+ top_db = None
+
+ self.enable_fusion = enable_fusion
+ self.fusion_type = fusion_type
+
+ # Spectrogram extractor
+ self.spectrogram_extractor = Spectrogram(
+ n_fft=window_size,
+ hop_length=hop_size,
+ win_length=window_size,
+ window=window,
+ center=center,
+ pad_mode=pad_mode,
+ freeze_parameters=True,
+ )
+
+ # Logmel feature extractor
+ self.logmel_extractor = LogmelFilterBank(
+ sr=sample_rate,
+ n_fft=window_size,
+ n_mels=mel_bins,
+ fmin=fmin,
+ fmax=fmax,
+ ref=ref,
+ amin=amin,
+ top_db=top_db,
+ freeze_parameters=True,
+ )
+
+ # Spec augmenter
+ self.spec_augmenter = SpecAugmentation(
+ time_drop_width=64,
+ time_stripes_num=2,
+ freq_drop_width=8,
+ freq_stripes_num=2,
+ )
+
+ self.bn0 = nn.BatchNorm2d(64)
+
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
+
+ self.fc1 = nn.Linear(1024, 1024, bias=True)
+ self.fc_audioset = nn.Linear(1024, classes_num, bias=True)
+
+ self.init_weight()
+
+ def init_weight(self):
+ init_bn(self.bn0)
+ init_layer(self.fc1)
+ init_layer(self.fc_audioset)
+
+ def forward(self, input, mixup_lambda=None, device=None):
+        """Input: (batch_size, data_length)"""
+
+ x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
+
+ x = x.transpose(1, 3)
+ x = self.bn0(x)
+ x = x.transpose(1, 3)
+
+ if self.training:
+ x = self.spec_augmenter(x)
+
+ # Mixup on spectrogram
+ if self.training and mixup_lambda is not None:
+ x = do_mixup(x, mixup_lambda)
+
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
+ x = F.dropout(x, p=0.2, training=self.training)
+ x = torch.mean(x, dim=3)
+
+ latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
+ latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
+ latent_x = latent_x1 + latent_x2
+ latent_x = latent_x.transpose(1, 2)
+ latent_x = F.relu_(self.fc1(latent_x))
+ latent_output = interpolate(latent_x, 32)
+
+ (x1, _) = torch.max(x, dim=2)
+ x2 = torch.mean(x, dim=2)
+ x = x1 + x2
+ x = F.dropout(x, p=0.5, training=self.training)
+ x = F.relu_(self.fc1(x))
+ embedding = F.dropout(x, p=0.5, training=self.training)
+ clipwise_output = torch.sigmoid(self.fc_audioset(x))
+
+ output_dict = {
+ "clipwise_output": clipwise_output,
+ "embedding": embedding,
+ "fine_grained_embedding": latent_output,
+ }
+
+ return output_dict
+
+
+def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"):
+ try:
+ ModelProto = eval(audio_cfg.model_name)
+ model = ModelProto(
+ sample_rate=audio_cfg.sample_rate,
+ window_size=audio_cfg.window_size,
+ hop_size=audio_cfg.hop_size,
+ mel_bins=audio_cfg.mel_bins,
+ fmin=audio_cfg.fmin,
+ fmax=audio_cfg.fmax,
+ classes_num=audio_cfg.class_num,
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type,
+ )
+ return model
+    except Exception as e:
+        raise RuntimeError(
+            f"Model {audio_cfg.model_name} not found, or the audio cfg parameters are incomplete."
+        ) from e
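+# Usage sketch (comment only; `cfg` stands for any object exposing the attributes
+# read above, e.g. the CLAPAudioCfp dataclass from open_clip.model):
+#     from types import SimpleNamespace
+#     cfg = SimpleNamespace(model_name="Cnn14", sample_rate=48000, window_size=1024,
+#                           hop_size=480, mel_bins=64, fmin=50, fmax=14000, class_num=527)
+#     model = create_pann_model(cfg)   # fusion disabled by default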
diff --git a/audioldm/clap/open_clip/pretrained.py b/audioldm/clap/open_clip/pretrained.py
new file mode 100755
index 0000000000000000000000000000000000000000..8ed2ae1732a28c4e98d1f3412157ef27054e41dc
--- /dev/null
+++ b/audioldm/clap/open_clip/pretrained.py
@@ -0,0 +1,169 @@
+import hashlib
+import os
+import urllib
+import warnings
+
+from tqdm import tqdm
+
+CACHE_DIR = os.getenv("AUDIOLDM_CACHE_DIR", "~/.cache")
+
+_RN50 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
+ cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
+)
+
+_RN50_quickgelu = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
+ cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
+)
+
+_RN101 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
+)
+
+_RN101_quickgelu = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+ yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
+)
+
+_RN50x4 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+)
+
+_RN50x16 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
+)
+
+_RN50x64 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
+)
+
+_VITB32 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+ laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
+ laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
+ laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
+)
+
+_VITB32_quickgelu = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+ laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
+ laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
+ laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
+)
+
+_VITB16 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+)
+
+_VITL14 = dict(
+ openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
+)
+
+_PRETRAINED = {
+ "RN50": _RN50,
+ "RN50-quickgelu": _RN50_quickgelu,
+ "RN101": _RN101,
+ "RN101-quickgelu": _RN101_quickgelu,
+ "RN50x4": _RN50x4,
+ "RN50x16": _RN50x16,
+ "ViT-B-32": _VITB32,
+ "ViT-B-32-quickgelu": _VITB32_quickgelu,
+ "ViT-B-16": _VITB16,
+ "ViT-L-14": _VITL14,
+}
+
+
+def list_pretrained(as_str: bool = False):
+    """Returns the list of pretrained models.
+    Each entry is a (model_name, pretrain_tag) tuple by default, or a 'name:tag' string if as_str == True.
+    """
+ return [
+ ":".join([k, t]) if as_str else (k, t)
+ for k in _PRETRAINED.keys()
+ for t in _PRETRAINED[k].keys()
+ ]
+
+
+def list_pretrained_tag_models(tag: str):
+ """return all models having the specified pretrain tag"""
+ models = []
+ for k in _PRETRAINED.keys():
+ if tag in _PRETRAINED[k]:
+ models.append(k)
+ return models
+
+
+def list_pretrained_model_tags(model: str):
+ """return all pretrain tags for the specified model architecture"""
+ tags = []
+ if model in _PRETRAINED:
+ tags.extend(_PRETRAINED[model].keys())
+ return tags
+
+
+def get_pretrained_url(model: str, tag: str):
+ if model not in _PRETRAINED:
+ return ""
+ model_pretrained = _PRETRAINED[model]
+ if tag not in model_pretrained:
+ return ""
+ return model_pretrained[tag]
+
+
+def download_pretrained(url: str, root: str = os.path.expanduser(f"{CACHE_DIR}/clip")):
+ os.makedirs(root, exist_ok=True)
+ filename = os.path.basename(url)
+
+ if "openaipublic" in url:
+ expected_sha256 = url.split("/")[-2]
+ else:
+ expected_sha256 = ""
+
+ download_target = os.path.join(root, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if expected_sha256:
+ if (
+ hashlib.sha256(open(download_target, "rb").read()).hexdigest()
+ == expected_sha256
+ ):
+ return download_target
+ else:
+ warnings.warn(
+ f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
+ )
+ else:
+ return download_target
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(
+ total=int(source.info().get("Content-Length")),
+ ncols=80,
+ unit="iB",
+ unit_scale=True,
+ ) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if (
+ expected_sha256
+ and hashlib.sha256(open(download_target, "rb").read()).hexdigest()
+ != expected_sha256
+ ):
+ raise RuntimeError(
+            "Model has been downloaded but the SHA256 checksum does not match"
+ )
+
+ return download_target
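+
+
+# Usage sketch (illustrative, not executed): list the available pretrained
+# checkpoints, resolve a URL, and download it into the cache directory.
+#
+#     print(list_pretrained(as_str=True))           # e.g. ["RN50:openai", ...]
+#     url = get_pretrained_url("ViT-B-32", "openai")
+#     if url:
+#         path = download_pretrained(url)           # verifies SHA256 for openai URLs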
diff --git a/audioldm/clap/open_clip/timm_model.py b/audioldm/clap/open_clip/timm_model.py
new file mode 100755
index 0000000000000000000000000000000000000000..c9d1ab4666b5bab5038d44b90c9ddca5087de460
--- /dev/null
+++ b/audioldm/clap/open_clip/timm_model.py
@@ -0,0 +1,112 @@
+""" timm model adapter
+
+Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
+"""
+from collections import OrderedDict
+
+import torch.nn as nn
+
+try:
+ import timm
+ from timm.models.layers import Mlp, to_2tuple
+ from timm.models.layers.attention_pool2d import RotAttentionPool2d
+ from timm.models.layers.attention_pool2d import (
+ AttentionPool2d as AbsAttentionPool2d,
+ )
+except ImportError as e:
+ timm = None
+
+from .utils import freeze_batch_norm_2d
+
+
+class TimmModel(nn.Module):
+ """timm model adapter
+ # FIXME this adapter is a work in progress, may change in ways that break weight compat
+ """
+
+ def __init__(
+ self,
+ model_name,
+ embed_dim,
+ image_size=224,
+ pool="avg",
+ proj="linear",
+ drop=0.0,
+ pretrained=False,
+ ):
+ super().__init__()
+ if timm is None:
+ raise RuntimeError("Please `pip install timm` to use timm models.")
+
+ self.image_size = to_2tuple(image_size)
+ self.trunk = timm.create_model(model_name, pretrained=pretrained)
+ feat_size = self.trunk.default_cfg.get("pool_size", None)
+ feature_ndim = 1 if not feat_size else 2
+ if pool in ("abs_attn", "rot_attn"):
+ assert feature_ndim == 2
+ # if attn pooling used, remove both classifier and default pool
+ self.trunk.reset_classifier(0, global_pool="")
+ else:
+ # reset global pool if pool config set, otherwise leave as network default
+ reset_kwargs = dict(global_pool=pool) if pool else {}
+ self.trunk.reset_classifier(0, **reset_kwargs)
+ prev_chs = self.trunk.num_features
+
+ head_layers = OrderedDict()
+ if pool == "abs_attn":
+ head_layers["pool"] = AbsAttentionPool2d(
+ prev_chs, feat_size=feat_size, out_features=embed_dim
+ )
+ prev_chs = embed_dim
+ elif pool == "rot_attn":
+ head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
+ prev_chs = embed_dim
+ else:
+ assert proj, "projection layer needed if non-attention pooling is used."
+
+ # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
+ if proj == "linear":
+ head_layers["drop"] = nn.Dropout(drop)
+ head_layers["proj"] = nn.Linear(prev_chs, embed_dim)
+ elif proj == "mlp":
+ head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop)
+
+ self.head = nn.Sequential(head_layers)
+
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
+ """lock modules
+ Args:
+ unlocked_groups (int): leave last n layer groups unlocked (default: 0)
+ """
+ if not unlocked_groups:
+ # lock full model
+ for param in self.trunk.parameters():
+ param.requires_grad = False
+ if freeze_bn_stats:
+ freeze_batch_norm_2d(self.trunk)
+ else:
+ # NOTE: partial freeze requires latest timm (master) branch and is subject to change
+ try:
+ # FIXME import here until API stable and in an official release
+ from timm.models.helpers import group_parameters, group_modules
+ except ImportError:
+ raise RuntimeError(
+ "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`"
+ )
+ matcher = self.trunk.group_matcher()
+ gparams = group_parameters(self.trunk, matcher)
+ max_layer_id = max(gparams.keys())
+ max_layer_id = max_layer_id - unlocked_groups
+ for group_idx in range(max_layer_id + 1):
+ group = gparams[group_idx]
+ for param in group:
+ self.trunk.get_parameter(param).requires_grad = False
+ if freeze_bn_stats:
+ gmodules = group_modules(self.trunk, matcher, reverse=True)
+ gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
+ freeze_batch_norm_2d(self.trunk, gmodules)
+
+ def forward(self, x):
+ x = self.trunk(x)
+ x = self.head(x)
+ return x
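+
+
+# Usage sketch (illustrative, not executed; requires `timm`): wrap a timm
+# backbone as a CLIP-style vision tower with a linear projection head.
+# "resnet50" and the 512-dim embedding are arbitrary example choices.
+#
+#     import torch
+#     tower = TimmModel("resnet50", embed_dim=512, image_size=224, pretrained=False)
+#     feats = tower(torch.randn(2, 3, 224, 224))    # -> (2, 512)
+#     tower.lock(unlocked_groups=0, freeze_bn_stats=True)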
diff --git a/audioldm/clap/open_clip/tokenizer.py b/audioldm/clap/open_clip/tokenizer.py
new file mode 100755
index 0000000000000000000000000000000000000000..ee4d28450ec5dd12a79daf38cf3088e9e73c2cd5
--- /dev/null
+++ b/audioldm/clap/open_clip/tokenizer.py
@@ -0,0 +1,197 @@
+""" CLIP tokenizer
+
+Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
+"""
+import gzip
+import html
+import os
+from functools import lru_cache
+from typing import Union, List
+
+import ftfy
+import regex as re
+import torch
+
+
+@lru_cache()
+def default_bpe():
+ return os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
+ )
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = (
+ list(range(ord("!"), ord("~") + 1))
+ + list(range(ord("¡"), ord("¬") + 1))
+ + list(range(ord("®"), ord("ÿ") + 1))
+ )
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8 + n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
+ merges = merges[1 : 49152 - 256 - 2 + 1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v + "</w>" for v in vocab]
+ for merge in merges:
+ vocab.append("".join(merge))
+        if not special_tokens:
+            special_tokens = ["<start_of_text>", "<end_of_text>"]
+        else:
+            special_tokens = ["<start_of_text>", "<end_of_text>"] + special_tokens
+ vocab.extend(special_tokens)
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {t: t for t in special_tokens}
+ special = "|".join(special_tokens)
+ self.pat = re.compile(
+ special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+ re.IGNORECASE,
+ )
+
+ self.vocab_size = len(self.encoder)
+ self.all_special_ids = [self.encoder[t] for t in special_tokens]
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + "</w>",)
+ pairs = get_pairs(word)
+
+ if not pairs:
+            return token + "</w>"
+
+ while True:
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+                except ValueError:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+ new_word.append(first + second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = " ".join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+ bpe_tokens.extend(
+ self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
+ )
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = "".join([self.decoder[token] for token in tokens])
+ text = (
+ bytearray([self.byte_decoder[c] for c in text])
+ .decode("utf-8", errors="replace")
+            .replace("</w>", " ")
+ )
+ return text
+
+
+_tokenizer = SimpleTokenizer()
+
+
+def tokenize(
+ texts: Union[str, List[str]], context_length: int = 77
+) -> torch.LongTensor:
+ """
+ Returns the tokenized representation of given input string(s)
+
+ Parameters
+ ----------
+ texts : Union[str, List[str]]
+ An input string or a list of input strings to tokenize
+ context_length : int
+ The context length to use; all CLIP models use 77 as the context length
+
+ Returns
+ -------
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+ """
+ if isinstance(texts, str):
+ texts = [texts]
+
+    sot_token = _tokenizer.encoder["<start_of_text>"]
+    eot_token = _tokenizer.encoder["<end_of_text>"]
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > context_length:
+ tokens = tokens[:context_length] # Truncate
+ result[i, : len(tokens)] = torch.tensor(tokens)
+
+ return result
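+
+
+# Usage sketch (illustrative, not executed): tokenize a batch of captions into
+# fixed-length token id tensors padded or truncated to the 77-token context.
+#
+#     tokens = tokenize(["a dog barking", "rain falling on a window"])
+#     tokens.shape      # torch.Size([2, 77])
+#     tokens[0, 0]      # id of the <start_of_text> token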
diff --git a/audioldm/clap/open_clip/transform.py b/audioldm/clap/open_clip/transform.py
new file mode 100755
index 0000000000000000000000000000000000000000..77aaa722c4a5544ac50de6df35d3e922f63b111d
--- /dev/null
+++ b/audioldm/clap/open_clip/transform.py
@@ -0,0 +1,45 @@
+from torchvision.transforms import (
+ Normalize,
+ Compose,
+ RandomResizedCrop,
+ InterpolationMode,
+ ToTensor,
+ Resize,
+ CenterCrop,
+)
+
+
+def _convert_to_rgb(image):
+ return image.convert("RGB")
+
+
+def image_transform(
+ image_size: int,
+ is_train: bool,
+ mean=(0.48145466, 0.4578275, 0.40821073),
+ std=(0.26862954, 0.26130258, 0.27577711),
+):
+ normalize = Normalize(mean=mean, std=std)
+ if is_train:
+ return Compose(
+ [
+ RandomResizedCrop(
+ image_size,
+ scale=(0.9, 1.0),
+ interpolation=InterpolationMode.BICUBIC,
+ ),
+ _convert_to_rgb,
+ ToTensor(),
+ normalize,
+ ]
+ )
+ else:
+ return Compose(
+ [
+ Resize(image_size, interpolation=InterpolationMode.BICUBIC),
+ CenterCrop(image_size),
+ _convert_to_rgb,
+ ToTensor(),
+ normalize,
+ ]
+ )
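+
+
+# Usage sketch (illustrative, not executed): build the eval-time preprocessing
+# pipeline and apply it to a PIL image. The file path is hypothetical.
+#
+#     from PIL import Image
+#     preprocess = image_transform(224, is_train=False)
+#     img = preprocess(Image.open("some_image.jpg"))    # -> (3, 224, 224) tensor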
diff --git a/audioldm/clap/open_clip/utils.py b/audioldm/clap/open_clip/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..34ecbced4cb7e6b6f92154a666e2c7efc7c922c6
--- /dev/null
+++ b/audioldm/clap/open_clip/utils.py
@@ -0,0 +1,362 @@
+import numpy as np
+import torch
+from torch import nn as nn
+from torchvision.ops.misc import FrozenBatchNorm2d
+import logging
+
+# import h5py
+from tqdm import tqdm
+import random
+import json
+import os
+import pathlib
+
+# TODO: (yusong) this is not a good place to store this information and it does not scale. Needs to be fixed later.
+dataset_split = {
+ "audiocaps": ["train", "valid", "test"],
+ "audioset": ["balanced_train", "unbalanced_train", "eval"],
+ "BBCSoundEffects": ["train", "test"],
+ "Clotho": ["train", "test", "valid"],
+ "free_to_use_sounds": ["train", "test"],
+ "paramount_motion": ["train", "test"],
+ "sonniss_game_effects": ["train", "test"],
+ "wesoundeffects": ["train", "test"],
+ "MACS": ["train", "test"],
+ "freesound": ["train", "test"],
+ "FSD50K": ["train", "test", "valid"],
+ "fsd50k_class_label": ["train", "test", "valid"],
+ "esc50": ["train", "test"],
+ "audiostock": ["train", "test"],
+ "freesound_no_overlap_noesc50": ["train", "test"],
+ "epidemic_sound_effects": ["train", "test"],
+ "VGGSound": ["train", "test"],
+ "urbansound8k_class_label": ["train", "test"],
+ "audioset_t5": ["balanced_train", "unbalanced_train", "eval"],
+ "epidemic_sound_effects_t5": ["train", "test"],
+ "WavText5K": ["train", "test"],
+ "esc50_no_overlap": ["train", "test"],
+ "usd8k_no_overlap": ["train", "test"],
+ "fsd50k_200_class_label": ["train", "test", "valid"],
+}
+
+
+def freeze_batch_norm_2d(module, module_match={}, name=""):
+ """
+ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
+ itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
+ returned. Otherwise, the module is walked recursively and submodules are converted in place.
+
+ Args:
+ module (torch.nn.Module): Any PyTorch module.
+ module_match (dict): Dictionary of full module names to freeze (all if empty)
+ name (str): Full module name (prefix)
+
+ Returns:
+ torch.nn.Module: Resulting module
+
+ Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
+ """
+ res = module
+ is_match = True
+ if module_match:
+ is_match = name in module_match
+ if is_match and isinstance(
+ module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)
+ ):
+ res = FrozenBatchNorm2d(module.num_features)
+ res.num_features = module.num_features
+ res.affine = module.affine
+ if module.affine:
+ res.weight.data = module.weight.data.clone().detach()
+ res.bias.data = module.bias.data.clone().detach()
+ res.running_mean.data = module.running_mean.data
+ res.running_var.data = module.running_var.data
+ res.eps = module.eps
+ else:
+ for child_name, child in module.named_children():
+ full_child_name = ".".join([name, child_name]) if name else child_name
+ new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
+ if new_child is not child:
+ res.add_module(child_name, new_child)
+ return res
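+
+# Usage sketch (illustrative, not executed): freeze the BatchNorm statistics of
+# a torchvision backbone; submodules are converted to FrozenBatchNorm2d in place.
+#
+#     import torchvision
+#     backbone = freeze_batch_norm_2d(torchvision.models.resnet18())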
+
+
+def exist(dataset_name, dataset_type):
+    """
+    Check if the given split exists for the dataset.
+    """
+    return dataset_type in dataset_split[dataset_name]
+
+
+def get_tar_path_from_dataset_name(
+ dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None
+):
+ """
+ Get tar path from dataset name and type
+ """
+ output = []
+ for n in dataset_names:
+ if full_dataset is not None and n in full_dataset:
+ current_dataset_types = dataset_split[n]
+ else:
+ current_dataset_types = dataset_types
+ for s in current_dataset_types:
+ tmp = []
+ if islocal:
+ sizefilepath_ = f"{dataset_path}/{n}/{s}/sizes.json"
+ if not os.path.exists(sizefilepath_):
+ sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
+ else:
+ sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
+ if not os.path.exists(sizefilepath_):
+ continue
+ sizes = json.load(open(sizefilepath_, "r"))
+ for k in sizes.keys():
+ if islocal:
+ tmp.append(f"{dataset_path}/{n}/{s}/{k}")
+ else:
+ tmp.append(
+ f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -"
+ )
+ if proportion != 1:
+ tmp = random.sample(tmp, int(proportion * len(tmp)))
+ output.append(tmp)
+ return sum(output, [])
+
+
+def get_tar_path_from_txts(txt_path, islocal, proportion=1):
+ """
+ Get tar path from txt path
+ """
+ if isinstance(txt_path, (list, tuple)):
+ return sum(
+ [
+ get_tar_path_from_txts(
+ txt_path[i], islocal=islocal, proportion=proportion
+ )
+ for i in range(len(txt_path))
+ ],
+ [],
+ )
+ if isinstance(txt_path, str):
+ with open(txt_path) as f:
+ lines = f.readlines()
+ if islocal:
+ lines = [
+ lines[i]
+ .split("\n")[0]
+ .replace("pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/")
+ for i in range(len(lines))
+ ]
+ else:
+ lines = [
+ lines[i].split("\n")[0].replace(".tar", ".tar -")
+ for i in range(len(lines))
+ ]
+ if proportion != 1:
+ print("Sampling tars with proportion of {}".format(proportion))
+ lines = random.sample(lines, int(proportion * len(lines)))
+ return lines
+
+
+def get_mix_lambda(mixup_alpha, batch_size):
+ mixup_lambdas = [
+ np.random.beta(mixup_alpha, mixup_alpha, 1)[0] for _ in range(batch_size)
+ ]
+ return np.array(mixup_lambdas).astype(np.float32)
+
+
+def do_mixup(x, mixup_lambda):
+ """
+ Args:
+ x: (batch_size , ...)
+ mixup_lambda: (batch_size,)
+ Returns:
+ out: (batch_size, ...)
+ """
+ out = (
+ x.transpose(0, -1) * mixup_lambda
+ + torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda)
+ ).transpose(0, -1)
+ return out
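+
+# Usage sketch (illustrative, not executed): draw per-sample mixup weights and
+# mix a batch with its reversed copy along the batch dimension.
+#
+#     lam = torch.from_numpy(get_mix_lambda(mixup_alpha=0.5, batch_size=4))
+#     x = torch.randn(4, 64, 1001)        # e.g. (batch, mel_bins, time)
+#     mixed = do_mixup(x, lam)            # same shape as x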
+
+
+def interpolate(x, ratio):
+ """Interpolate data in time domain. This is used to compensate the
+ resolution reduction in downsampling of a CNN.
+
+ Args:
+ x: (batch_size, time_steps, classes_num)
+ ratio: int, ratio to interpolate
+ Returns:
+ upsampled: (batch_size, time_steps * ratio, classes_num)
+ """
+ (batch_size, time_steps, classes_num) = x.shape
+ upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
+ upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
+ return upsampled
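+
+# Usage sketch (illustrative, not executed): upsample framewise predictions in
+# time to undo the temporal pooling of the CNN.
+#
+#     x = torch.randn(2, 32, 527)         # (batch, time_steps, classes_num)
+#     y = interpolate(x, ratio=32)        # -> (2, 1024, 527)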
+
+
+def pad_framewise_output(framewise_output, frames_num):
+ """Pad framewise_output to the same length as input frames. The pad value
+ is the same as the value of the last frame.
+ Args:
+ framewise_output: (batch_size, frames_num, classes_num)
+ frames_num: int, number of frames to pad
+ Outputs:
+ output: (batch_size, frames_num, classes_num)
+ """
+ pad = framewise_output[:, -1:, :].repeat(
+ 1, frames_num - framewise_output.shape[1], 1
+ )
+ """tensor for padding"""
+
+    output = torch.cat((framewise_output, pad), dim=1)
+    """(batch_size, frames_num, classes_num)"""
+
+    return output
+
+
+# def process_ipc(index_path, classes_num, filename):
+# # load data
+# logging.info("Load Data...............")
+# ipc = [[] for _ in range(classes_num)]
+# with h5py.File(index_path, "r") as f:
+# for i in tqdm(range(len(f["target"]))):
+# t_class = np.where(f["target"][i])[0]
+# for t in t_class:
+# ipc[t].append(i)
+# print(ipc)
+# np.save(filename, ipc)
+# logging.info("Load Data Succeed...............")
+
+
+def save_to_dict(s, o_=None):
+    # Avoid a mutable default argument; callers may still pass in an existing dict.
+    if o_ is None:
+        o_ = {}
+    sp = s.split(": ")
+    o_.update({sp[0]: float(sp[1])})
+    return o_
+
+
+def get_data_from_log(txt_path):
+ """
+ Output dictionary from out.txt log file
+ """
+ with open(txt_path) as f:
+ lines = f.readlines()
+ val_data = {}
+ train_data = {}
+ train_losses = []
+ train_losses_epoch = []
+ for i in range(len(lines)):
+ if "| INFO |" in lines[i]:
+ if "Eval Epoch" in lines[i]:
+ if "val_loss" in lines[i]:
+ # float(regex.sub("", lines[310].split(" ")[-1]).replace(" ", ""))
+ line = lines[i].split("Eval Epoch: ")[-1]
+ num_epoch = int(line.split(" ")[0].split(" ")[0])
+ d = {
+ line.split(" ")[0]
+ .split(" ")[1]
+ .replace(":", ""): float(line.split(" ")[0].split(" ")[-1])
+ }
+ for i in range(1, len(line.split(" "))):
+ d = save_to_dict(line.split(" ")[i], d)
+ val_data[num_epoch] = d
+ elif "Train Epoch" in lines[i]:
+ num_epoch = int(lines[i].split("Train Epoch: ")[1][0])
+ loss = float(lines[i].split("Loss: ")[-1].split(" (")[0])
+ train_losses.append(loss)
+ train_losses_epoch.append(num_epoch)
+ for i in range(len(train_losses)):
+ train_data[i] = {
+ "num_epoch": train_losses_epoch[i],
+ "train_loss": train_losses[i],
+ }
+ return train_data, val_data
+
+
+def save_p(obj, filename):
+ import pickle
+
+ try:
+ from deepdiff import DeepDiff
+    except ImportError:
+ os.system("pip install deepdiff")
+ from deepdiff import DeepDiff
+ with open(filename, "wb") as file:
+ pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol
+ with open(filename, "rb") as file:
+ z = pickle.load(file)
+ assert (
+ DeepDiff(obj, z, ignore_string_case=True) == {}
+ ), "there is something wrong with the saving process"
+ return
+
+
+def load_p(filename):
+ import pickle
+
+ with open(filename, "rb") as file:
+ z = pickle.load(file)
+ return z
+
+
+def save_json(data, name="data.json"):
+ import json
+
+ with open(name, "w") as fp:
+ json.dump(data, fp)
+ return
+
+
+def load_json(name):
+ import json
+
+ with open(name, "r") as fp:
+ data = json.load(fp)
+ return data
+
+
+from multiprocessing import Process, Manager
+from multiprocessing import Process, Value, Array
+from ctypes import c_wchar
+
+
+def load_class_label(path):
+ # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
+ # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
+ out = None
+ if path is not None:
+ if pathlib.Path(path).suffix in [".pkl", ".pickle"]:
+ out = load_p(path)
+ elif pathlib.Path(path).suffix in [".json", ".txt"]:
+ out = load_json(path)
+ elif pathlib.Path(path).suffix in [".npy", ".npz"]:
+ out = np.load(path)
+ elif pathlib.Path(path).suffix in [".csv"]:
+ import pandas as pd
+
+ out = pd.read_csv(path)
+ return out
+ # if out is None:
+ # return None
+ # else:
+ # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False)
+ # val = Array('i', out.values(), lock=False)
+ # return (key, val)
+
+
+from torch import optim
+
+
+def get_optimizer(params, lr, betas, eps, momentum, optimizer_name):
+ if optimizer_name.lower() == "adamw":
+ optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps)
+ elif optimizer_name.lower() == "sgd":
+ optimizer = optim.SGD(params, lr=lr, momentum=momentum)
+ elif optimizer_name.lower() == "adam":
+ optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps)
+ else:
+ raise ValueError("optimizer name is not correct")
+ return optimizer
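+
+
+# Usage sketch (illustrative, not executed): the momentum argument is only used
+# by SGD; betas/eps only by the Adam variants. `model` here stands for any
+# torch.nn.Module and is an assumption of this sketch.
+#
+#     opt = get_optimizer(
+#         model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-8,
+#         momentum=0.9, optimizer_name="adamw",
+#     )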
diff --git a/audioldm/clap/open_clip/version.py b/audioldm/clap/open_clip/version.py
new file mode 100755
index 0000000000000000000000000000000000000000..3ced3581bb601ae91b1e1da4b8f4f520855a065e
--- /dev/null
+++ b/audioldm/clap/open_clip/version.py
@@ -0,0 +1 @@
+__version__ = "0.2.1"
diff --git a/audioldm/clap/training/__init__.py b/audioldm/clap/training/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm/clap/training/audioset_textmap.npy b/audioldm/clap/training/audioset_textmap.npy
new file mode 100755
index 0000000000000000000000000000000000000000..3da4c92d3819aaec11e5f576464a9973a6df811b
--- /dev/null
+++ b/audioldm/clap/training/audioset_textmap.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bada103070d92f9eadd33e1b4f45ec8583f59080ef218c966b43294bd4c86d5b
+size 84448
diff --git a/audioldm/clap/training/data.py b/audioldm/clap/training/data.py
new file mode 100755
index 0000000000000000000000000000000000000000..a005fee2f51e577446839b8cffd117d9ae93abc9
--- /dev/null
+++ b/audioldm/clap/training/data.py
@@ -0,0 +1,981 @@
+import ast
+import json
+import logging
+import math
+import os
+import random
+
+# import h5py
+from dataclasses import dataclass
+from audioldm.clap.training.params import parse_args
+
+# import braceexpand
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.datasets as datasets
+import torchvision.transforms
+
+# import webdataset as wds
+from PIL import Image
+from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
+from torch.utils.data.distributed import DistributedSampler
+from functools import partial
+import soundfile as sf
+import io
+from pathlib import Path
+
+# import wget
+
+from audioldm.clap.open_clip.utils import (
+ get_tar_path_from_dataset_name,
+ dataset_split,
+)
+from audioldm.clap.open_clip.utils import load_p, load_class_label
+import copy
+
+try:
+ import horovod.torch as hvd
+except ImportError:
+ hvd = None
+
+try:
+ import torchaudio
+except ImportError:
+ torchaudio = None
+
+from audioldm.clap.open_clip import tokenize
+
+
+def tokenizer(text):
+ return tokenize(text).squeeze(0)
+
+
+# NOTE: the Roberta tokenizer below rebinds `tokenize` (imported from open_clip
+# above), and the `tokenizer` defined next supersedes the BPE-based one above.
+from transformers import RobertaTokenizer
+
+tokenize = RobertaTokenizer.from_pretrained("roberta-base")
+
+
+def tokenizer(text):
+ result = tokenize(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ )
+ return {k: v.squeeze(0) for k, v in result.items()}
+
+
+# initialize the audioset map
+_AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy")
+_AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True)
+
+
+def int16_to_float32(x):
+ return (x / 32767.0).astype(np.float32)
+
+
+def float32_to_int16(x):
+ x = np.clip(x, a_min=-1.0, a_max=1.0)
+ return (x * 32767.0).astype(np.int16)
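+
+# Usage sketch (illustrative, not executed): the pair below quantizes a waveform
+# to 16-bit resolution and back, matching how audio is stored in the tars.
+#
+#     wav = np.random.uniform(-1, 1, 48000).astype(np.float32)
+#     wav_q = int16_to_float32(float32_to_int16(wav))   # max abs error ~ 1/32767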
+
+
+# For Toy Dataset
+# class ToyDataset(Dataset):
+# def __init__(self, index_path, ipc, config, eval_mode=False):
+# """Toy Dataset for testing the audioset input with text labels
+# Parameters
+# ----------
+# index_path: str
+# the link to the h5 file of each audio
+# idc: str
+# the link to the npy file, the number of samples in each class
+# config: dict
+# the audio cfg file
+# eval_model (bool): to indicate if the dataset is a testing dataset
+# """
+# self.audio_cfg = config["audio_cfg"]
+# self.text_cfg = config["text_cfg"]
+# self.fp = h5py.File(index_path, "r")
+# self.ipc = np.load(ipc, allow_pickle=True)
+# self.total_size = len(self.fp["audio_name"])
+# self.classes_num = self.audio_cfg["class_num"]
+# self.eval_mode = eval_mode
+
+# if not eval_mode:
+# self.generate_queue()
+# else:
+# self.queue = []
+# for i in range(self.total_size):
+# target = self.fp["target"][i]
+# if np.sum(target) > 0:
+# self.queue.append(i)
+# self.total_size = len(self.queue)
+# logging.info("total dataset size: %d" % (self.total_size))
+# logging.info("class num: %d" % (self.classes_num))
+
+# def time_shifting(self, x):
+# frame_num = len(x)
+# shift_len = random.randint(0, frame_num - 1)
+# new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0)
+# return new_sample
+
+# def generate_queue(self):
+# self.queue = []
+# while len(self.queue) < self.total_size:
+# class_set = [*range(self.classes_num)]
+# random.shuffle(class_set)
+# self.queue += [
+# self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set
+# ]
+# self.queue = self.queue[: self.total_size]
+
+# logging.info("queue regenerated:%s" % (self.queue[-5:]))
+
+# def crop_wav(self, x):
+# crop_size = self.audio_cfg["crop_size"]
+# crop_pos = random.randint(0, len(x) - crop_size - 1)
+# return x[crop_pos : crop_pos + crop_size]
+
+# def prompt_text(self, target):
+# events = _AUDIOSET_MAP[np.where(target > 0)]
+# event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1]
+# text = tokenize(event_text)[0]
+# return text
+
+# def __getitem__(self, index):
+# """Load waveform, text, and target of an audio clip
+
+# Parameters
+# ----------
+# index: int
+# the index number
+# Return
+# ------
+# output: dict {
+# "hdf5_path": str,
+# "index_in_hdf5": int,
+# "audio_name": str,
+# "waveform": list (audio_length,),
+# "target": list (class_num, ),
+# "text": torch.tensor (context_length,)
+# }
+# the output dictionary
+# """
+# s_index = self.queue[index]
+
+# audio_name = self.fp["audio_name"][s_index].decode()
+# # Hardcode here CHANGE
+# hdf5_path = (
+# self.fp["hdf5_path"][s_index]
+# .decode()
+# .replace(
+# "../workspace",
+# "/home/la/kechen/Research/ke_zsasp/workspace",
+# )
+# )
+# r_idx = self.fp["index_in_hdf5"][s_index]
+# target = self.fp["target"][s_index].astype(np.float32)
+# text = self.prompt_text(target)
+# with h5py.File(hdf5_path, "r") as f:
+# waveform = int16_to_float32(f["waveform"][r_idx])[
+# : self.audio_cfg["clip_samples"]
+# ]
+# assert (
+# len(waveform) == self.audio_cfg["clip_samples"]
+# ), "The sample length is not match"
+# # Time shift
+# # if (self.config.enable_time_shift) and (not self.eval_mode):
+# # waveform = self.time_shifting(waveform)
+# # # Label Enhance
+# # if (self.config.crop_size is not None) and (not self.eval_mode):
+# # waveform = self.crop_wav(waveform)
+# # # the label enhance rate is fixed 0.5
+# # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5:
+# # kidx = np.where(target)[0]
+# # for k in kidx:
+# # for add_key in self.class_map[k][1]:
+# # target[add_key] = 1.0
+# # if len(self.class_map[k][2]) > 0:
+# # add_key = random.choice(self.class_map[k][2])
+# # target[add_key] = 1.0
+
+# # missing the text input
+# mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :]
+# mel_spec = (
+# torch.cat(
+# [mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0
+# )
+# .cpu()
+# .numpy()
+# )
+# longer = random.choice([True, False])
+# if longer == False:
+# mel_spec[1:, :, :] = 0.0
+# data_dict = {
+# "hdf5_path": hdf5_path,
+# "index_in_hdf5": r_idx,
+# "audio_name": audio_name,
+# "waveform": waveform,
+# "class_label": target,
+# "text": text,
+# "longer": longer,
+# "mel_fusion": mel_spec,
+# }
+# return data_dict
+
+# def __len__(self):
+# return self.total_size
+
+
+class CsvDataset(Dataset):
+ def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
+ logging.debug(f"Loading csv data from {input_filename}.")
+ df = pd.read_csv(input_filename, sep=sep)
+
+ self.images = df[img_key].tolist()
+ self.captions = df[caption_key].tolist()
+ self.transforms = transforms
+ logging.debug("Done loading data.")
+
+ def __len__(self):
+ return len(self.captions)
+
+ def __getitem__(self, idx):
+ images = self.transforms(Image.open(str(self.images[idx])))
+ texts = tokenize([str(self.captions[idx])])[0]
+ return images, texts
+
+
+@dataclass
+class DataInfo:
+ dataloader: DataLoader
+ sampler: DistributedSampler
+
+
+def preprocess_txt(text):
+ return tokenize([str(text)])[0]
+
+
+def get_dataset_size(shards, sizefilepath_=None, is_local=True):
+ if isinstance(shards, list):
+ size_list = []
+ for s in shards:
+ size_list.append(
+ get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0]
+ )
+ else:
+ if not is_local:
+ for n in dataset_split.keys():
+ if n in shards.split("/"):
+ break
+ for s in dataset_split[n]:
+ if s in shards.split("/"):
+ break
+ sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
+ shards_list = list(braceexpand.braceexpand(shards))
+ dir_path = os.path.dirname(shards)
+ if sizefilepath_ is not None:
+ sizes = json.load(open(sizefilepath_, "r"))
+ total_size = sum(
+ [
+ int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))])
+ for shard in shards_list
+ ]
+ )
+ else:
+ sizes_filename = os.path.join(dir_path, "sizes.json")
+ len_filename = os.path.join(dir_path, "__len__")
+ if os.path.exists(sizes_filename):
+ sizes = json.load(open(sizes_filename, "r"))
+ total_size = sum(
+ [int(sizes[os.path.basename(shard)]) for shard in shards_list]
+ )
+ elif os.path.exists(len_filename):
+ # FIXME this used to be eval(open(...)) but that seemed rather unsafe
+ total_size = ast.literal_eval(open(len_filename, "r").read())
+ else:
+ raise Exception(
+ "Cannot find sizes file for dataset. Please specify the path to the file."
+ )
+ # total_size = None # num samples undefined
+ # some common dataset sizes (at time of authors last download)
+ # cc3m-train: 2905954
+ # cc12m: 10968539
+ # LAION-400m: 407332084
+ num_shards = len(shards_list)
+ if isinstance(shards, list):
+ return sum(size_list), len(shards)
+ else:
+ return total_size, num_shards
+
+
+def get_imagenet(args, preprocess_fns, split):
+ assert split in ["train", "val", "v2"]
+ is_train = split == "train"
+ preprocess_train, preprocess_val = preprocess_fns
+
+ if split == "v2":
+ from imagenetv2_pytorch import ImageNetV2Dataset
+
+ dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
+ else:
+ if is_train:
+ data_path = args.imagenet_train
+ preprocess_fn = preprocess_train
+ else:
+ data_path = args.imagenet_val
+ preprocess_fn = preprocess_val
+ assert data_path
+
+ dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
+
+ if is_train:
+ idxs = np.zeros(len(dataset.targets))
+ target_array = np.array(dataset.targets)
+ k = 50
+ for c in range(1000):
+ m = target_array == c
+ n = len(idxs[m])
+ arr = np.zeros(n)
+ arr[:k] = 1
+ np.random.shuffle(arr)
+ idxs[m] = arr
+
+ idxs = idxs.astype("int")
+ sampler = SubsetRandomSampler(np.where(idxs)[0])
+ else:
+ sampler = None
+
+ dataloader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=args.batch_size,
+ num_workers=args.workers,
+ sampler=sampler,
+ )
+
+ return DataInfo(dataloader, sampler)
+
+
+def count_samples(dataloader):
+ os.environ["WDS_EPOCH"] = "0"
+ n_elements, n_batches = 0, 0
+ for images, texts in dataloader:
+ n_batches += 1
+ n_elements += len(images)
+ assert len(images) == len(texts)
+ return n_elements, n_batches
+
+
+def filter_no_caption(sample):
+ return "txt" in sample
+
+
+def log_and_continue(exn):
+    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
+ logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
+ return True
+
+
+_SHARD_SHUFFLE_SIZE = 2000
+_SHARD_SHUFFLE_INITIAL = 500
+_SAMPLE_SHUFFLE_SIZE = 5000
+_SAMPLE_SHUFFLE_INITIAL = 1000
+
+
+def sample_prop(sizefile, inputs, proportion, is_local=True):
+ """
+ Sample a proportion of the data.
+ """
+ file_path_dict = {
+ os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0]
+ for i in range(len(inputs))
+ }
+ sampled_filepath_dict = {}
+ sampled_size_dict = {}
+ if not is_local:
+ if os.path.exists("sizes.json"):
+ os.remove("sizes.json")
+ wget.download(sizefile, "sizes.json")
+ sizefile = "sizes.json"
+ with open(sizefile, "r", encoding="UTF-8") as f:
+ load_dict = json.load(f)
+ L = int(len(file_path_dict) * proportion)
+    subkeys = random.sample(list(file_path_dict.keys()), L)
+ for k in subkeys:
+ sampled_size_dict[k] = load_dict[k]
+ sampled_filepath_dict[k] = file_path_dict[k]
+ return (
+ sum(sampled_size_dict.values()),
+ L,
+ [os.path.join(v, k) for k, v in sampled_filepath_dict.items()],
+ sampled_size_dict,
+ )
+
+
+def get_mel(audio_data, audio_cfg):
+ # mel shape: (n_mels, T)
+ mel = torchaudio.transforms.MelSpectrogram(
+ sample_rate=audio_cfg["sample_rate"],
+ n_fft=audio_cfg["window_size"],
+ win_length=audio_cfg["window_size"],
+ hop_length=audio_cfg["hop_size"],
+ center=True,
+ pad_mode="reflect",
+ power=2.0,
+ norm=None,
+ onesided=True,
+ n_mels=64,
+ f_min=audio_cfg["fmin"],
+ f_max=audio_cfg["fmax"],
+ ).to(audio_data.device)
+ mel = mel(audio_data)
+ # Align to librosa:
+ # librosa_melspec = librosa.feature.melspectrogram(
+ # waveform,
+ # sr=audio_cfg['sample_rate'],
+ # n_fft=audio_cfg['window_size'],
+ # hop_length=audio_cfg['hop_size'],
+ # win_length=audio_cfg['window_size'],
+ # center=True,
+ # pad_mode="reflect",
+ # power=2.0,
+ # n_mels=64,
+ # norm=None,
+ # htk=True,
+ # f_min=audio_cfg['fmin'],
+ # f_max=audio_cfg['fmax']
+ # )
+ # we use log mel spectrogram as input
+ mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
+ return mel.T # (T, n_mels)
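+
+# Usage sketch (illustrative, not executed): compute a log-mel spectrogram for a
+# 10-second clip. The audio_cfg values are assumptions mirroring a 48 kHz setup.
+#
+#     audio_cfg = {"sample_rate": 48000, "window_size": 1024,
+#                  "hop_size": 480, "fmin": 50, "fmax": 14000}
+#     wav = torch.randn(48000 * 10)
+#     mel = get_mel(wav, audio_cfg)       # -> (T, 64), T ~= 10 * 48000 / 480 + 1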
+
+
+def get_audio_features(
+ sample, audio_data, max_len, data_truncating, data_filling, audio_cfg
+):
+ """
+ Calculate and add audio features to sample.
+ Sample: a dict containing all the data of current sample.
+ audio_data: a tensor of shape (T) containing audio data.
+ max_len: the maximum length of audio data.
+ data_truncating: the method of truncating data.
+ data_filling: the method of filling data.
+ audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg'].
+ """
+ with torch.no_grad():
+ if len(audio_data) > max_len:
+ if data_truncating == "rand_trunc":
+ longer = torch.tensor([True])
+ elif data_truncating == "fusion":
+ # fusion
+ mel = get_mel(audio_data, audio_cfg)
+ # split to three parts
+ chunk_frames = (
+ max_len // audio_cfg["hop_size"] + 1
+ ) # the +1 related to how the spectrogram is computed
+ total_frames = mel.shape[0]
+ if chunk_frames == total_frames:
+ # there is a corner case where the audio length is
+ # larger than max_len but smaller than max_len+hop_size.
+ # In this case, we just use the whole audio.
+ mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
+ sample["mel_fusion"] = mel_fusion
+ longer = torch.tensor([False])
+ else:
+ ranges = np.array_split(
+ list(range(0, total_frames - chunk_frames + 1)), 3
+ )
+ # print('total_frames-chunk_frames:', total_frames-chunk_frames,
+ # 'len(audio_data):', len(audio_data),
+ # 'chunk_frames:', chunk_frames,
+ # 'total_frames:', total_frames)
+ if len(ranges[1]) == 0:
+ # if the audio is too short, we just use the first chunk
+ ranges[1] = [0]
+ if len(ranges[2]) == 0:
+ # if the audio is too short, we just use the first chunk
+ ranges[2] = [0]
+ # randomly choose index for each part
+ idx_front = np.random.choice(ranges[0])
+ idx_middle = np.random.choice(ranges[1])
+ idx_back = np.random.choice(ranges[2])
+ # select mel
+ mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :]
+ mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :]
+ mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :]
+
+ # shrink the mel
+ mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, 64])(
+ mel[None]
+ )[0]
+ # logging.info(f"mel_shrink.shape: {mel_shrink.shape}")
+
+ # stack
+ mel_fusion = torch.stack(
+ [mel_chunk_front, mel_chunk_middle, mel_chunk_back, mel_shrink],
+ dim=0,
+ )
+ sample["mel_fusion"] = mel_fusion
+ longer = torch.tensor([True])
+ else:
+ raise NotImplementedError(
+ f"data_truncating {data_truncating} not implemented"
+ )
+ # random crop to max_len (for compatibility)
+ overflow = len(audio_data) - max_len
+ idx = np.random.randint(0, overflow + 1)
+ audio_data = audio_data[idx : idx + max_len]
+
+ else: # padding if too short
+ if len(audio_data) < max_len: # do nothing if equal
+ if data_filling == "repeatpad":
+ n_repeat = int(max_len / len(audio_data))
+ audio_data = audio_data.repeat(n_repeat)
+ # audio_data = audio_data.unsqueeze(0).unsqueeze(0).unsqueeze(0)
+ # audio_data = F.interpolate(audio_data,size=max_len,mode="bicubic")[0,0,0]
+ audio_data = F.pad(
+ audio_data,
+ (0, max_len - len(audio_data)),
+ mode="constant",
+ value=0,
+ )
+ elif data_filling == "pad":
+ audio_data = F.pad(
+ audio_data,
+ (0, max_len - len(audio_data)),
+ mode="constant",
+ value=0,
+ )
+ elif data_filling == "repeat":
+ n_repeat = int(max_len / len(audio_data))
+ audio_data = audio_data.repeat(n_repeat + 1)[:max_len]
+ else:
+ raise NotImplementedError(
+ f"data_filling {data_filling} not implemented"
+ )
+ if data_truncating == "fusion":
+ mel = get_mel(audio_data, audio_cfg)
+ mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
+ sample["mel_fusion"] = mel_fusion
+ longer = torch.tensor([False])
+
+ sample["longer"] = longer
+ sample["waveform"] = audio_data
+
+ return sample
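+
+# Usage sketch (illustrative, not executed): a 12-second 48 kHz clip is longer
+# than max_len, so with "fusion" truncation the sample gains a (4, chunk_frames, 64)
+# "mel_fusion" tensor plus a randomly cropped "waveform" of length max_len.
+# The audio_cfg dict is the same assumed config as in the get_mel sketch above.
+#
+#     sample = get_audio_features(
+#         {}, torch.randn(48000 * 12), max_len=480000,
+#         data_truncating="fusion", data_filling="repeatpad", audio_cfg=audio_cfg,
+#     )
+#     sorted(sample.keys())   # ["longer", "mel_fusion", "waveform"]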
+
+
+def preprocess(
+ sample,
+ audio_ext,
+ text_ext,
+ max_len,
+ audio_cfg,
+ class_index_dict=None,
+ data_filling="pad",
+ data_truncating="rand_trunc",
+ text_augment_selection=None,
+):
+ """
+ Preprocess a single sample for wdsdataloader.
+ """
+ audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
+ audio_data = int16_to_float32(float32_to_int16(audio_data))
+ audio_data = torch.tensor(audio_data).float()
+
+ # TODO: (yusong) to be include in the future
+ # # if torchaudio not installed, use soundfile to load audio
+ # if torchaudio is None:
+ # audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
+ # audio_data = torch.tensor(audio_data).float()
+ # else:
+ # # https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py
+ # with tempfile.TemporaryDirectory() as dirname:
+ # os.makedirs(dirname, exist_ok=True)
+ # fname = os.path.join(dirname, f"file.flac")
+ # with open(fname, "wb") as stream:
+ # stream.write(sample[audio_ext])
+ # audio_data, orig_sr = torchaudio.load(fname)
+ # audio_data = audio_data[0, :].float()
+
+ sample = get_audio_features(
+ sample, audio_data, max_len, data_truncating, data_filling, audio_cfg
+ )
+ del sample[audio_ext]
+
+    try:
+        json_dict_raw = json.loads(sample[text_ext].decode("utf-8"))
+    except Exception:
+        # Log the offending shard before failing; without re-raising, the later
+        # use of json_dict_raw would fail with a less informative NameError.
+        print("sample[__url__]:", sample["__url__"])
+        raise
+
+ # For selecting augmented text from dataset
+ if text_augment_selection is None or text_augment_selection == "none":
+ texts = json_dict_raw["text"]
+ elif text_augment_selection == "all":
+ if "text_augment_all" in json_dict_raw.keys():
+ texts = json_dict_raw["text_augment_all"]
+ else:
+ texts = json_dict_raw["text"]
+ elif text_augment_selection == "augment_only":
+ if "text_augment_all" in json_dict_raw.keys():
+ if json_dict_raw["text_augment_t5"] is None:
+ texts = json_dict_raw["text"]
+ else:
+ texts = json_dict_raw["text_augment_t5"]
+ else:
+ texts = json_dict_raw["text"]
+ else:
+ raise NotImplementedError(
+ f"text_augment_selection {text_augment_selection} not implemented"
+ )
+ sample["full_text"] = texts
+
+ if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1:
+ texts = random.choice(texts)
+ sample["raw_text"] = texts
+ sample["text"] = tokenizer(texts) # text shape: [num_token]
+ if class_index_dict is not None:
+ # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
+ # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
+ # key, val = class_index_dict
+ # key = key[:].split('\n')
+ # _dict = {k: v for k, v in zip(key, val)}
+ sample["class_label"] = np.zeros(len(class_index_dict.keys()))
+ for x in json_dict_raw["tag"]:
+ sample["class_label"][class_index_dict[x]] = 1
+ sample["class_label"] = torch.tensor(sample["class_label"]).float()
+ del sample[text_ext]
+ sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext
+ sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext
+ sample["audio_orig_sr"] = orig_sr
+ return sample
+
+
+def collate_fn(batch):
+ """
+ Collate function for wdsdataloader.
+ batch: a list of dict, each dict is a sample
+ """
+ # concatenate values in each dictionary. if it is a tensor, concatenate. if it is a list, extend.
+ batch_dict = {}
+ for k in batch[0].keys():
+        if isinstance(batch[0][k], dict):  # deal with bert tokenizer output
+ batch_dict[k] = {}
+ for kk in batch[0][k].keys():
+ tmp = []
+ for i in range(len(batch)):
+ tmp.append(batch[i][k][kk])
+ batch_dict[k][kk] = torch.vstack(tmp)
+ elif isinstance(batch[0][k], torch.Tensor):
+ batch_dict[k] = torch.stack([sample[k] for sample in batch])
+ elif isinstance(batch[0][k], np.ndarray):
+ batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in batch]))
+ else:
+ batch_dict[k] = [sample[k] for sample in batch]
+ return batch_dict
+
+
+def get_wds_dataset(
+ args,
+ model_cfg,
+ is_train,
+ audio_ext="flac",
+ text_ext="json",
+ max_len=480000,
+ proportion=1.0,
+ sizefilepath_=None,
+ is_local=None,
+):
+    """
+    Get a dataset for wdsdataloader.
+
+    NOTE: this code path expects the optional `webdataset`, `braceexpand` and
+    `wget` dependencies, whose imports are commented out at the top of this file.
+    """
+    if is_local is None and args.remotedata is not None:
+ is_local = not args.remotedata
+
+ input_shards = args.train_data if is_train else args.val_data
+ assert input_shards is not None
+
+    if sizefilepath_ is not None:
+ sizefilepath = sizefilepath_
+ else:
+ sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json")
+
+ if proportion != 1.0:
+ num_samples, num_shards, input_shards, _ = sample_prop(
+ sizefilepath, input_shards, proportion, is_local=is_local
+ )
+ else:
+ num_samples, num_shards = get_dataset_size(
+ input_shards, sizefilepath_=sizefilepath_, is_local=is_local
+ )
+
+ if not num_samples:
+ if is_train:
+ num_samples = args.train_num_samples
+ if not num_samples:
+ raise RuntimeError(
+ "Currently, number of dataset samples must be specified for training dataset. "
+ "Please specify via `--train-num-samples` if no dataset length info present."
+ )
+ else:
+ num_samples = (
+ args.val_num_samples or 0
+ ) # eval will just exhaust the iterator if not specified
+
+ pipeline = [wds.SimpleShardList(input_shards)]
+ # at this point we have an iterator over all the shards
+ # TODO: (yusong): add a if statement of distributed. If not, we don't need to split_by_node
+ if is_train or args.parallel_eval:
+ pipeline.extend(
+ [
+ wds.detshuffle(
+ bufsize=_SHARD_SHUFFLE_SIZE,
+ initial=_SHARD_SHUFFLE_INITIAL,
+ seed=args.seed,
+ ),
+ wds.split_by_node,
+ wds.split_by_worker,
+ # at this point, we have an iterator over the shards assigned to each worker at each node
+ wds.tarfile_to_samples(handler=log_and_continue),
+ wds.shuffle(
+ bufsize=_SAMPLE_SHUFFLE_SIZE,
+ initial=_SAMPLE_SHUFFLE_INITIAL,
+ rng=random.Random(args.seed),
+ ),
+ # wds.repeatedly, # FIXME determine if this is beneficial
+ ]
+ )
+ else:
+ pipeline.extend(
+ [
+ wds.split_by_worker,
+ # at this point, we have an iterator over the shards assigned to each worker
+ wds.tarfile_to_samples(handler=log_and_continue),
+ ]
+ )
+ pipeline.append(
+ wds.map(
+ partial(
+ preprocess,
+ audio_ext=audio_ext,
+ text_ext=text_ext,
+ max_len=max_len,
+ audio_cfg=model_cfg["audio_cfg"],
+ class_index_dict=copy.deepcopy(args.class_index_dict),
+ data_filling=args.data_filling,
+ data_truncating=args.data_truncating,
+ text_augment_selection=args.text_augment_selection,
+ )
+ ),
+ )
+
+ pipeline.append(
+ wds.batched(
+ args.batch_size,
+ partial=not (is_train or args.parallel_eval),
+ collation_fn=collate_fn,
+ )
+ )
+
+ dataset = wds.DataPipeline(*pipeline)
+ if is_train or args.parallel_eval:
+        # (yusong): Currently parallel evaluation is not precise, as we repeat the last few samples.
+ # (yusong): See comments below.
+ # roll over and repeat a few samples to get same number of full batches on each node
+ global_batch_size = args.batch_size * args.world_size
+ num_batches = math.ceil(num_samples / global_batch_size)
+ num_workers = max(1, args.workers)
+ num_worker_batches = math.ceil(
+ num_batches / num_workers
+ ) # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * global_batch_size
+ dataset = dataset.with_epoch(
+ num_worker_batches
+ ) # each worker is iterating over this
+ else:
+ # last batches are partial, eval is done on single (master) node
+ num_batches = math.ceil(num_samples / args.batch_size)
+
+ kwargs = {}
+ if args.horovod: # multi-node training on summit
+ kwargs["multiprocessing_context"] = "forkserver"
+
+ dataloader = wds.WebLoader(
+ dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs
+ )
+
+ # FIXME not clear which approach is better, with_epoch before vs after dataloader?
+ # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
+ # if is_train:
+ # # roll over and repeat a few samples to get same number of full batches on each node
+ # global_batch_size = args.batch_size * args.world_size
+ # num_batches = math.ceil(num_samples / global_batch_size)
+ # num_workers = max(1, args.workers)
+ # num_batches = math.ceil(num_batches / num_workers) * num_workers
+ # num_samples = num_batches * global_batch_size
+ # dataloader = dataloader.with_epoch(num_batches)
+ # else:
+ # # last batches are partial, eval is done on single (master) node
+ # num_batches = math.ceil(num_samples / args.batch_size)
+
+ # add meta-data to dataloader instance for convenience
+ dataloader.num_batches = num_batches
+ dataloader.num_samples = num_samples
+
+ return DataInfo(dataloader, None)
+
+
+def wds_batch_list2dict(
+ batch,
+ keys=[
+ "__url__",
+ "__key__",
+ "waveform",
+ "text",
+ "raw_text",
+ "audio_name",
+ "text_name",
+ "audio_orig_sr",
+ ],
+):
+ """
+ Return a dictionary of the batch, with keys as the names of the fields.
+ """
+ assert len(keys) == len(
+ batch
+ ), "batch must have same number of keys as keys argument"
+ return {keys[i]: batch[i] for i in range(len(batch))}
+
+
+def get_csv_dataset(args, preprocess_fn, is_train):
+ input_filename = args.train_data if is_train else args.val_data
+ assert input_filename
+ dataset = CsvDataset(
+ input_filename,
+ preprocess_fn,
+ img_key=args.csv_img_key,
+ caption_key=args.csv_caption_key,
+ sep=args.csv_separator,
+ )
+ num_samples = len(dataset)
+ sampler = DistributedSampler(dataset) if args.distributed and is_train else None
+ shuffle = is_train and sampler is None
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=args.batch_size,
+ shuffle=shuffle,
+ num_workers=args.workers,
+ pin_memory=True,
+ sampler=sampler,
+ drop_last=is_train,
+ )
+ dataloader.num_samples = num_samples
+ dataloader.num_batches = len(dataloader)
+
+ return DataInfo(dataloader, sampler)
+
+
+def get_toy_dataset(args, model_cfg, is_train):
+ index_path = args.train_data if is_train else args.val_data
+ ipc_path = args.train_ipc if is_train else args.val_ipc
+ assert index_path and ipc_path
+ eval_mode = not is_train
+ dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode)
+
+ num_samples = len(dataset)
+ sampler = (
+ DistributedSampler(dataset, shuffle=False)
+ if args.distributed and is_train
+ else None
+ )
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=args.batch_size,
+ shuffle=False,
+ num_workers=args.workers,
+ sampler=sampler,
+ drop_last=is_train,
+ )
+ dataloader.num_samples = num_samples
+ dataloader.num_batches = len(dataloader)
+
+ return DataInfo(dataloader, sampler)
+
+
+def get_dataset_fn(data_path, dataset_type):
+ if dataset_type == "webdataset":
+ return get_wds_dataset
+ elif dataset_type == "csv":
+ return get_csv_dataset
+ elif dataset_type == "auto":
+ ext = data_path.split(".")[-1]
+ if ext in ["csv", "tsv"]:
+ return get_csv_dataset
+ elif ext in ["tar"]:
+ return get_wds_dataset
+ else:
+ raise ValueError(
+                f"Tried to figure out dataset type, but failed for extension {ext}."
+ )
+ elif dataset_type == "toy":
+ return get_toy_dataset
+ else:
+ raise ValueError(f"Unsupported dataset type: {dataset_type}")
+
+
+def get_data(args, model_cfg):
+ data = {}
+
+ args.class_index_dict = load_class_label(args.class_label_path)
+
+ if args.datasetinfos is None:
+ args.datasetinfos = ["train", "unbalanced_train", "balanced_train"]
+ if args.dataset_type == "webdataset":
+ args.train_data = get_tar_path_from_dataset_name(
+ args.datasetnames,
+ args.datasetinfos,
+ islocal=not args.remotedata,
+ proportion=args.dataset_proportion,
+ dataset_path=args.datasetpath,
+ full_dataset=args.full_train_dataset,
+ )
+
+ if args.full_train_dataset is None:
+ args.full_train_dataset = []
+ if args.exclude_eval_dataset is None:
+ args.exclude_eval_dataset = []
+ excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset
+
+ val_dataset_names = (
+ [n for n in args.datasetnames if n not in excluded_eval_datasets]
+ if excluded_eval_datasets
+ else args.datasetnames
+ )
+ args.val_dataset_names = val_dataset_names
+ args.val_data = get_tar_path_from_dataset_name(
+ val_dataset_names,
+ ["valid", "test", "eval"],
+ islocal=not args.remotedata,
+ proportion=1,
+ dataset_path=args.datasetpath,
+ full_dataset=None,
+ )
+
+ if args.train_data:
+ data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
+ args, model_cfg, is_train=True
+ )
+
+ if args.val_data:
+ data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
+ args, model_cfg, is_train=False
+ )
+
+ return data
diff --git a/audioldm/clap/training/distributed.py b/audioldm/clap/training/distributed.py
new file mode 100755
index 0000000000000000000000000000000000000000..2fa61f76c5cc3ab9f6a9643042afa8e1f2e1cb7f
--- /dev/null
+++ b/audioldm/clap/training/distributed.py
@@ -0,0 +1,150 @@
+import os
+
+import torch
+import socket
+
+try:
+ import horovod.torch as hvd
+except ImportError:
+ hvd = None
+
+
+def is_global_master(args):
+ return args.rank == 0
+
+
+def is_local_master(args):
+ return args.local_rank == 0
+
+
+def is_master(args, local=False):
+ return is_local_master(args) if local else is_global_master(args)
+
+
+def is_using_horovod():
+ # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
+ # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
+ ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
+ pmi_vars = ["PMI_RANK", "PMI_SIZE"]
+    return all(var in os.environ for var in ompi_vars) or all(
+        var in os.environ for var in pmi_vars
+    )
+
+
+def is_using_distributed():
+ if "WORLD_SIZE" in os.environ:
+ return int(os.environ["WORLD_SIZE"]) > 1
+ if "SLURM_NTASKS" in os.environ:
+ return int(os.environ["SLURM_NTASKS"]) > 1
+ return False
+
+
+def world_info_from_env():
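+    # Resolve (local_rank, global_rank, world_size) from whichever launcher exported them,
+    # checking SLURM, MPI/OpenMPI, and torch.distributed environment variables in order.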
+ local_rank = 0
+ for v in (
+ "SLURM_LOCALID",
+ "MPI_LOCALRANKID",
+ "OMPI_COMM_WORLD_LOCAL_RANK",
+ "LOCAL_RANK",
+ ):
+ if v in os.environ:
+ local_rank = int(os.environ[v])
+ break
+ global_rank = 0
+ for v in ("SLURM_PROCID", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "RANK"):
+ if v in os.environ:
+ global_rank = int(os.environ[v])
+ break
+ world_size = 1
+ for v in ("SLURM_NTASKS", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "WORLD_SIZE"):
+ if v in os.environ:
+ world_size = int(os.environ[v])
+ break
+
+ return local_rank, global_rank, world_size
+
+
+def init_distributed_device(args):
+ # Distributed training = training on more than one GPU.
+ # Works in both single and multi-node scenarios.
+ args.distributed = False
+ args.world_size = 1
+ args.rank = 0 # global rank
+ args.local_rank = 0
+ if args.horovod:
+ assert hvd is not None, "Horovod is not installed"
+ hvd.init()
+ world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
+ world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
+ local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
+ args.local_rank = local_rank
+ args.rank = world_rank
+ args.world_size = world_size
+ # args.local_rank = int(hvd.local_rank())
+ # args.rank = hvd.rank()
+ # args.world_size = hvd.size()
+ args.distributed = True
+ os.environ["LOCAL_RANK"] = str(args.local_rank)
+ os.environ["RANK"] = str(args.rank)
+ os.environ["WORLD_SIZE"] = str(args.world_size)
+ print(
+ f"Distributed training: local_rank={args.local_rank}, "
+ f"rank={args.rank}, world_size={args.world_size}, "
+ f"hostname={socket.gethostname()}, pid={os.getpid()}"
+ )
+ elif is_using_distributed():
+ if "SLURM_PROCID" in os.environ:
+ # DDP via SLURM
+ args.local_rank, args.rank, args.world_size = world_info_from_env()
+ # SLURM var -> torch.distributed vars in case needed
+ os.environ["LOCAL_RANK"] = str(args.local_rank)
+ os.environ["RANK"] = str(args.rank)
+ os.environ["WORLD_SIZE"] = str(args.world_size)
+ torch.distributed.init_process_group(
+ backend=args.dist_backend,
+ init_method=args.dist_url,
+ world_size=args.world_size,
+ rank=args.rank,
+ )
+ elif "OMPI_COMM_WORLD_SIZE" in os.environ: # using Summit cluster
+ world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
+ world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
+ local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
+ args.local_rank = local_rank
+ args.rank = world_rank
+ args.world_size = world_size
+ torch.distributed.init_process_group(
+ backend=args.dist_backend,
+ init_method=args.dist_url,
+ world_size=args.world_size,
+ rank=args.rank,
+ )
+ else:
+ # DDP via torchrun, torch.distributed.launch
+ args.local_rank, _, _ = world_info_from_env()
+ torch.distributed.init_process_group(
+ backend=args.dist_backend, init_method=args.dist_url
+ )
+ args.world_size = torch.distributed.get_world_size()
+ args.rank = torch.distributed.get_rank()
+ args.distributed = True
+ print(
+ f"Distributed training: local_rank={args.local_rank}, "
+ f"rank={args.rank}, world_size={args.world_size}, "
+ f"hostname={socket.gethostname()}, pid={os.getpid()}"
+ )
+
+ if torch.cuda.is_available():
+ if args.distributed and not args.no_set_device_rank:
+ device = "cuda:%d" % args.local_rank
+ else:
+ device = "cuda:0"
+ torch.cuda.set_device(device)
+ else:
+ device = "cpu"
+ args.device = device
+ device = torch.device(device)
+ return device
diff --git a/audioldm/clap/training/imagenet_zeroshot_data.py b/audioldm/clap/training/imagenet_zeroshot_data.py
new file mode 100755
index 0000000000000000000000000000000000000000..d32e55328d6799ccb8d61625f43abb80a33d6c17
--- /dev/null
+++ b/audioldm/clap/training/imagenet_zeroshot_data.py
@@ -0,0 +1,1088 @@
+# NOTE: This script is currently not supported for CLAP.
+
+imagenet_classnames = [
+ "tench",
+ "goldfish",
+ "great white shark",
+ "tiger shark",
+ "hammerhead shark",
+ "electric ray",
+ "stingray",
+ "rooster",
+ "hen",
+ "ostrich",
+ "brambling",
+ "goldfinch",
+ "house finch",
+ "junco",
+ "indigo bunting",
+ "American robin",
+ "bulbul",
+ "jay",
+ "magpie",
+ "chickadee",
+ "American dipper",
+ "kite (bird of prey)",
+ "bald eagle",
+ "vulture",
+ "great grey owl",
+ "fire salamander",
+ "smooth newt",
+ "newt",
+ "spotted salamander",
+ "axolotl",
+ "American bullfrog",
+ "tree frog",
+ "tailed frog",
+ "loggerhead sea turtle",
+ "leatherback sea turtle",
+ "mud turtle",
+ "terrapin",
+ "box turtle",
+ "banded gecko",
+ "green iguana",
+ "Carolina anole",
+ "desert grassland whiptail lizard",
+ "agama",
+ "frilled-necked lizard",
+ "alligator lizard",
+ "Gila monster",
+ "European green lizard",
+ "chameleon",
+ "Komodo dragon",
+ "Nile crocodile",
+ "American alligator",
+ "triceratops",
+ "worm snake",
+ "ring-necked snake",
+ "eastern hog-nosed snake",
+ "smooth green snake",
+ "kingsnake",
+ "garter snake",
+ "water snake",
+ "vine snake",
+ "night snake",
+ "boa constrictor",
+ "African rock python",
+ "Indian cobra",
+ "green mamba",
+ "sea snake",
+ "Saharan horned viper",
+ "eastern diamondback rattlesnake",
+ "sidewinder rattlesnake",
+ "trilobite",
+ "harvestman",
+ "scorpion",
+ "yellow garden spider",
+ "barn spider",
+ "European garden spider",
+ "southern black widow",
+ "tarantula",
+ "wolf spider",
+ "tick",
+ "centipede",
+ "black grouse",
+ "ptarmigan",
+ "ruffed grouse",
+ "prairie grouse",
+ "peafowl",
+ "quail",
+ "partridge",
+ "african grey parrot",
+ "macaw",
+ "sulphur-crested cockatoo",
+ "lorikeet",
+ "coucal",
+ "bee eater",
+ "hornbill",
+ "hummingbird",
+ "jacamar",
+ "toucan",
+ "duck",
+ "red-breasted merganser",
+ "goose",
+ "black swan",
+ "tusker",
+ "echidna",
+ "platypus",
+ "wallaby",
+ "koala",
+ "wombat",
+ "jellyfish",
+ "sea anemone",
+ "brain coral",
+ "flatworm",
+ "nematode",
+ "conch",
+ "snail",
+ "slug",
+ "sea slug",
+ "chiton",
+ "chambered nautilus",
+ "Dungeness crab",
+ "rock crab",
+ "fiddler crab",
+ "red king crab",
+ "American lobster",
+ "spiny lobster",
+ "crayfish",
+ "hermit crab",
+ "isopod",
+ "white stork",
+ "black stork",
+ "spoonbill",
+ "flamingo",
+ "little blue heron",
+ "great egret",
+ "bittern bird",
+ "crane bird",
+ "limpkin",
+ "common gallinule",
+ "American coot",
+ "bustard",
+ "ruddy turnstone",
+ "dunlin",
+ "common redshank",
+ "dowitcher",
+ "oystercatcher",
+ "pelican",
+ "king penguin",
+ "albatross",
+ "grey whale",
+ "killer whale",
+ "dugong",
+ "sea lion",
+ "Chihuahua",
+ "Japanese Chin",
+ "Maltese",
+ "Pekingese",
+ "Shih Tzu",
+ "King Charles Spaniel",
+ "Papillon",
+ "toy terrier",
+ "Rhodesian Ridgeback",
+ "Afghan Hound",
+ "Basset Hound",
+ "Beagle",
+ "Bloodhound",
+ "Bluetick Coonhound",
+ "Black and Tan Coonhound",
+ "Treeing Walker Coonhound",
+ "English foxhound",
+ "Redbone Coonhound",
+ "borzoi",
+ "Irish Wolfhound",
+ "Italian Greyhound",
+ "Whippet",
+ "Ibizan Hound",
+ "Norwegian Elkhound",
+ "Otterhound",
+ "Saluki",
+ "Scottish Deerhound",
+ "Weimaraner",
+ "Staffordshire Bull Terrier",
+ "American Staffordshire Terrier",
+ "Bedlington Terrier",
+ "Border Terrier",
+ "Kerry Blue Terrier",
+ "Irish Terrier",
+ "Norfolk Terrier",
+ "Norwich Terrier",
+ "Yorkshire Terrier",
+ "Wire Fox Terrier",
+ "Lakeland Terrier",
+ "Sealyham Terrier",
+ "Airedale Terrier",
+ "Cairn Terrier",
+ "Australian Terrier",
+ "Dandie Dinmont Terrier",
+ "Boston Terrier",
+ "Miniature Schnauzer",
+ "Giant Schnauzer",
+ "Standard Schnauzer",
+ "Scottish Terrier",
+ "Tibetan Terrier",
+ "Australian Silky Terrier",
+ "Soft-coated Wheaten Terrier",
+ "West Highland White Terrier",
+ "Lhasa Apso",
+ "Flat-Coated Retriever",
+ "Curly-coated Retriever",
+ "Golden Retriever",
+ "Labrador Retriever",
+ "Chesapeake Bay Retriever",
+ "German Shorthaired Pointer",
+ "Vizsla",
+ "English Setter",
+ "Irish Setter",
+ "Gordon Setter",
+ "Brittany dog",
+ "Clumber Spaniel",
+ "English Springer Spaniel",
+ "Welsh Springer Spaniel",
+ "Cocker Spaniel",
+ "Sussex Spaniel",
+ "Irish Water Spaniel",
+ "Kuvasz",
+ "Schipperke",
+ "Groenendael dog",
+ "Malinois",
+ "Briard",
+ "Australian Kelpie",
+ "Komondor",
+ "Old English Sheepdog",
+ "Shetland Sheepdog",
+ "collie",
+ "Border Collie",
+ "Bouvier des Flandres dog",
+ "Rottweiler",
+ "German Shepherd Dog",
+ "Dobermann",
+ "Miniature Pinscher",
+ "Greater Swiss Mountain Dog",
+ "Bernese Mountain Dog",
+ "Appenzeller Sennenhund",
+ "Entlebucher Sennenhund",
+ "Boxer",
+ "Bullmastiff",
+ "Tibetan Mastiff",
+ "French Bulldog",
+ "Great Dane",
+ "St. Bernard",
+ "husky",
+ "Alaskan Malamute",
+ "Siberian Husky",
+ "Dalmatian",
+ "Affenpinscher",
+ "Basenji",
+ "pug",
+ "Leonberger",
+ "Newfoundland dog",
+ "Great Pyrenees dog",
+ "Samoyed",
+ "Pomeranian",
+ "Chow Chow",
+ "Keeshond",
+ "brussels griffon",
+ "Pembroke Welsh Corgi",
+ "Cardigan Welsh Corgi",
+ "Toy Poodle",
+ "Miniature Poodle",
+ "Standard Poodle",
+ "Mexican hairless dog (xoloitzcuintli)",
+ "grey wolf",
+ "Alaskan tundra wolf",
+ "red wolf or maned wolf",
+ "coyote",
+ "dingo",
+ "dhole",
+ "African wild dog",
+ "hyena",
+ "red fox",
+ "kit fox",
+ "Arctic fox",
+ "grey fox",
+ "tabby cat",
+ "tiger cat",
+ "Persian cat",
+ "Siamese cat",
+ "Egyptian Mau",
+ "cougar",
+ "lynx",
+ "leopard",
+ "snow leopard",
+ "jaguar",
+ "lion",
+ "tiger",
+ "cheetah",
+ "brown bear",
+ "American black bear",
+ "polar bear",
+ "sloth bear",
+ "mongoose",
+ "meerkat",
+ "tiger beetle",
+ "ladybug",
+ "ground beetle",
+ "longhorn beetle",
+ "leaf beetle",
+ "dung beetle",
+ "rhinoceros beetle",
+ "weevil",
+ "fly",
+ "bee",
+ "ant",
+ "grasshopper",
+ "cricket insect",
+ "stick insect",
+ "cockroach",
+ "praying mantis",
+ "cicada",
+ "leafhopper",
+ "lacewing",
+ "dragonfly",
+ "damselfly",
+ "red admiral butterfly",
+ "ringlet butterfly",
+ "monarch butterfly",
+ "small white butterfly",
+ "sulphur butterfly",
+ "gossamer-winged butterfly",
+ "starfish",
+ "sea urchin",
+ "sea cucumber",
+ "cottontail rabbit",
+ "hare",
+ "Angora rabbit",
+ "hamster",
+ "porcupine",
+ "fox squirrel",
+ "marmot",
+ "beaver",
+ "guinea pig",
+ "common sorrel horse",
+ "zebra",
+ "pig",
+ "wild boar",
+ "warthog",
+ "hippopotamus",
+ "ox",
+ "water buffalo",
+ "bison",
+ "ram (adult male sheep)",
+ "bighorn sheep",
+ "Alpine ibex",
+ "hartebeest",
+ "impala (antelope)",
+ "gazelle",
+ "arabian camel",
+ "llama",
+ "weasel",
+ "mink",
+ "European polecat",
+ "black-footed ferret",
+ "otter",
+ "skunk",
+ "badger",
+ "armadillo",
+ "three-toed sloth",
+ "orangutan",
+ "gorilla",
+ "chimpanzee",
+ "gibbon",
+ "siamang",
+ "guenon",
+ "patas monkey",
+ "baboon",
+ "macaque",
+ "langur",
+ "black-and-white colobus",
+ "proboscis monkey",
+ "marmoset",
+ "white-headed capuchin",
+ "howler monkey",
+ "titi monkey",
+ "Geoffroy's spider monkey",
+ "common squirrel monkey",
+ "ring-tailed lemur",
+ "indri",
+ "Asian elephant",
+ "African bush elephant",
+ "red panda",
+ "giant panda",
+ "snoek fish",
+ "eel",
+ "silver salmon",
+ "rock beauty fish",
+ "clownfish",
+ "sturgeon",
+ "gar fish",
+ "lionfish",
+ "pufferfish",
+ "abacus",
+ "abaya",
+ "academic gown",
+ "accordion",
+ "acoustic guitar",
+ "aircraft carrier",
+ "airliner",
+ "airship",
+ "altar",
+ "ambulance",
+ "amphibious vehicle",
+ "analog clock",
+ "apiary",
+ "apron",
+ "trash can",
+ "assault rifle",
+ "backpack",
+ "bakery",
+ "balance beam",
+ "balloon",
+ "ballpoint pen",
+ "Band-Aid",
+ "banjo",
+ "baluster / handrail",
+ "barbell",
+ "barber chair",
+ "barbershop",
+ "barn",
+ "barometer",
+ "barrel",
+ "wheelbarrow",
+ "baseball",
+ "basketball",
+ "bassinet",
+ "bassoon",
+ "swimming cap",
+ "bath towel",
+ "bathtub",
+ "station wagon",
+ "lighthouse",
+ "beaker",
+ "military hat (bearskin or shako)",
+ "beer bottle",
+ "beer glass",
+ "bell tower",
+ "baby bib",
+ "tandem bicycle",
+ "bikini",
+ "ring binder",
+ "binoculars",
+ "birdhouse",
+ "boathouse",
+ "bobsleigh",
+ "bolo tie",
+ "poke bonnet",
+ "bookcase",
+ "bookstore",
+ "bottle cap",
+ "hunting bow",
+ "bow tie",
+ "brass memorial plaque",
+ "bra",
+ "breakwater",
+ "breastplate",
+ "broom",
+ "bucket",
+ "buckle",
+ "bulletproof vest",
+ "high-speed train",
+ "butcher shop",
+ "taxicab",
+ "cauldron",
+ "candle",
+ "cannon",
+ "canoe",
+ "can opener",
+ "cardigan",
+ "car mirror",
+ "carousel",
+ "tool kit",
+ "cardboard box / carton",
+ "car wheel",
+ "automated teller machine",
+ "cassette",
+ "cassette player",
+ "castle",
+ "catamaran",
+ "CD player",
+ "cello",
+ "mobile phone",
+ "chain",
+ "chain-link fence",
+ "chain mail",
+ "chainsaw",
+ "storage chest",
+ "chiffonier",
+ "bell or wind chime",
+ "china cabinet",
+ "Christmas stocking",
+ "church",
+ "movie theater",
+ "cleaver",
+ "cliff dwelling",
+ "cloak",
+ "clogs",
+ "cocktail shaker",
+ "coffee mug",
+ "coffeemaker",
+ "spiral or coil",
+ "combination lock",
+ "computer keyboard",
+ "candy store",
+ "container ship",
+ "convertible",
+ "corkscrew",
+ "cornet",
+ "cowboy boot",
+ "cowboy hat",
+ "cradle",
+ "construction crane",
+ "crash helmet",
+ "crate",
+ "infant bed",
+ "Crock Pot",
+ "croquet ball",
+ "crutch",
+ "cuirass",
+ "dam",
+ "desk",
+ "desktop computer",
+ "rotary dial telephone",
+ "diaper",
+ "digital clock",
+ "digital watch",
+ "dining table",
+ "dishcloth",
+ "dishwasher",
+ "disc brake",
+ "dock",
+ "dog sled",
+ "dome",
+ "doormat",
+ "drilling rig",
+ "drum",
+ "drumstick",
+ "dumbbell",
+ "Dutch oven",
+ "electric fan",
+ "electric guitar",
+ "electric locomotive",
+ "entertainment center",
+ "envelope",
+ "espresso machine",
+ "face powder",
+ "feather boa",
+ "filing cabinet",
+ "fireboat",
+ "fire truck",
+ "fire screen",
+ "flagpole",
+ "flute",
+ "folding chair",
+ "football helmet",
+ "forklift",
+ "fountain",
+ "fountain pen",
+ "four-poster bed",
+ "freight car",
+ "French horn",
+ "frying pan",
+ "fur coat",
+ "garbage truck",
+ "gas mask or respirator",
+ "gas pump",
+ "goblet",
+ "go-kart",
+ "golf ball",
+ "golf cart",
+ "gondola",
+ "gong",
+ "gown",
+ "grand piano",
+ "greenhouse",
+ "radiator grille",
+ "grocery store",
+ "guillotine",
+ "hair clip",
+ "hair spray",
+ "half-track",
+ "hammer",
+ "hamper",
+ "hair dryer",
+ "hand-held computer",
+ "handkerchief",
+ "hard disk drive",
+ "harmonica",
+ "harp",
+ "combine harvester",
+ "hatchet",
+ "holster",
+ "home theater",
+ "honeycomb",
+ "hook",
+ "hoop skirt",
+ "gymnastic horizontal bar",
+ "horse-drawn vehicle",
+ "hourglass",
+ "iPod",
+ "clothes iron",
+ "carved pumpkin",
+ "jeans",
+ "jeep",
+ "T-shirt",
+ "jigsaw puzzle",
+ "rickshaw",
+ "joystick",
+ "kimono",
+ "knee pad",
+ "knot",
+ "lab coat",
+ "ladle",
+ "lampshade",
+ "laptop computer",
+ "lawn mower",
+ "lens cap",
+ "letter opener",
+ "library",
+ "lifeboat",
+ "lighter",
+ "limousine",
+ "ocean liner",
+ "lipstick",
+ "slip-on shoe",
+ "lotion",
+ "music speaker",
+ "loupe magnifying glass",
+ "sawmill",
+ "magnetic compass",
+ "messenger bag",
+ "mailbox",
+ "tights",
+ "one-piece bathing suit",
+ "manhole cover",
+ "maraca",
+ "marimba",
+ "mask",
+ "matchstick",
+ "maypole",
+ "maze",
+ "measuring cup",
+ "medicine cabinet",
+ "megalith",
+ "microphone",
+ "microwave oven",
+ "military uniform",
+ "milk can",
+ "minibus",
+ "miniskirt",
+ "minivan",
+ "missile",
+ "mitten",
+ "mixing bowl",
+ "mobile home",
+ "ford model t",
+ "modem",
+ "monastery",
+ "monitor",
+ "moped",
+ "mortar and pestle",
+ "graduation cap",
+ "mosque",
+ "mosquito net",
+ "vespa",
+ "mountain bike",
+ "tent",
+ "computer mouse",
+ "mousetrap",
+ "moving van",
+ "muzzle",
+ "metal nail",
+ "neck brace",
+ "necklace",
+ "baby pacifier",
+ "notebook computer",
+ "obelisk",
+ "oboe",
+ "ocarina",
+ "odometer",
+ "oil filter",
+ "pipe organ",
+ "oscilloscope",
+ "overskirt",
+ "bullock cart",
+ "oxygen mask",
+ "product packet / packaging",
+ "paddle",
+ "paddle wheel",
+ "padlock",
+ "paintbrush",
+ "pajamas",
+ "palace",
+ "pan flute",
+ "paper towel",
+ "parachute",
+ "parallel bars",
+ "park bench",
+ "parking meter",
+ "railroad car",
+ "patio",
+ "payphone",
+ "pedestal",
+ "pencil case",
+ "pencil sharpener",
+ "perfume",
+ "Petri dish",
+ "photocopier",
+ "plectrum",
+ "Pickelhaube",
+ "picket fence",
+ "pickup truck",
+ "pier",
+ "piggy bank",
+ "pill bottle",
+ "pillow",
+ "ping-pong ball",
+ "pinwheel",
+ "pirate ship",
+ "drink pitcher",
+ "block plane",
+ "planetarium",
+ "plastic bag",
+ "plate rack",
+ "farm plow",
+ "plunger",
+ "Polaroid camera",
+ "pole",
+ "police van",
+ "poncho",
+ "pool table",
+ "soda bottle",
+ "plant pot",
+ "potter's wheel",
+ "power drill",
+ "prayer rug",
+ "printer",
+ "prison",
+ "missile",
+ "projector",
+ "hockey puck",
+ "punching bag",
+ "purse",
+ "quill",
+ "quilt",
+ "race car",
+ "racket",
+ "radiator",
+ "radio",
+ "radio telescope",
+ "rain barrel",
+ "recreational vehicle",
+ "fishing casting reel",
+ "reflex camera",
+ "refrigerator",
+ "remote control",
+ "restaurant",
+ "revolver",
+ "rifle",
+ "rocking chair",
+ "rotisserie",
+ "eraser",
+ "rugby ball",
+ "ruler measuring stick",
+ "sneaker",
+ "safe",
+ "safety pin",
+ "salt shaker",
+ "sandal",
+ "sarong",
+ "saxophone",
+ "scabbard",
+ "weighing scale",
+ "school bus",
+ "schooner",
+ "scoreboard",
+ "CRT monitor",
+ "screw",
+ "screwdriver",
+ "seat belt",
+ "sewing machine",
+ "shield",
+ "shoe store",
+ "shoji screen / room divider",
+ "shopping basket",
+ "shopping cart",
+ "shovel",
+ "shower cap",
+ "shower curtain",
+ "ski",
+ "balaclava ski mask",
+ "sleeping bag",
+ "slide rule",
+ "sliding door",
+ "slot machine",
+ "snorkel",
+ "snowmobile",
+ "snowplow",
+ "soap dispenser",
+ "soccer ball",
+ "sock",
+ "solar thermal collector",
+ "sombrero",
+ "soup bowl",
+ "keyboard space bar",
+ "space heater",
+ "space shuttle",
+ "spatula",
+ "motorboat",
+ "spider web",
+ "spindle",
+ "sports car",
+ "spotlight",
+ "stage",
+ "steam locomotive",
+ "through arch bridge",
+ "steel drum",
+ "stethoscope",
+ "scarf",
+ "stone wall",
+ "stopwatch",
+ "stove",
+ "strainer",
+ "tram",
+ "stretcher",
+ "couch",
+ "stupa",
+ "submarine",
+ "suit",
+ "sundial",
+ "sunglasses",
+ "sunglasses",
+ "sunscreen",
+ "suspension bridge",
+ "mop",
+ "sweatshirt",
+ "swim trunks / shorts",
+ "swing",
+ "electrical switch",
+ "syringe",
+ "table lamp",
+ "tank",
+ "tape player",
+ "teapot",
+ "teddy bear",
+ "television",
+ "tennis ball",
+ "thatched roof",
+ "front curtain",
+ "thimble",
+ "threshing machine",
+ "throne",
+ "tile roof",
+ "toaster",
+ "tobacco shop",
+ "toilet seat",
+ "torch",
+ "totem pole",
+ "tow truck",
+ "toy store",
+ "tractor",
+ "semi-trailer truck",
+ "tray",
+ "trench coat",
+ "tricycle",
+ "trimaran",
+ "tripod",
+ "triumphal arch",
+ "trolleybus",
+ "trombone",
+ "hot tub",
+ "turnstile",
+ "typewriter keyboard",
+ "umbrella",
+ "unicycle",
+ "upright piano",
+ "vacuum cleaner",
+ "vase",
+ "vaulted or arched ceiling",
+ "velvet fabric",
+ "vending machine",
+ "vestment",
+ "viaduct",
+ "violin",
+ "volleyball",
+ "waffle iron",
+ "wall clock",
+ "wallet",
+ "wardrobe",
+ "military aircraft",
+ "sink",
+ "washing machine",
+ "water bottle",
+ "water jug",
+ "water tower",
+ "whiskey jug",
+ "whistle",
+ "hair wig",
+ "window screen",
+ "window shade",
+ "Windsor tie",
+ "wine bottle",
+ "airplane wing",
+ "wok",
+ "wooden spoon",
+ "wool",
+ "split-rail fence",
+ "shipwreck",
+ "sailboat",
+ "yurt",
+ "website",
+ "comic book",
+ "crossword",
+ "traffic or street sign",
+ "traffic light",
+ "dust jacket",
+ "menu",
+ "plate",
+ "guacamole",
+ "consomme",
+ "hot pot",
+ "trifle",
+ "ice cream",
+ "popsicle",
+ "baguette",
+ "bagel",
+ "pretzel",
+ "cheeseburger",
+ "hot dog",
+ "mashed potatoes",
+ "cabbage",
+ "broccoli",
+ "cauliflower",
+ "zucchini",
+ "spaghetti squash",
+ "acorn squash",
+ "butternut squash",
+ "cucumber",
+ "artichoke",
+ "bell pepper",
+ "cardoon",
+ "mushroom",
+ "Granny Smith apple",
+ "strawberry",
+ "orange",
+ "lemon",
+ "fig",
+ "pineapple",
+ "banana",
+ "jackfruit",
+ "cherimoya (custard apple)",
+ "pomegranate",
+ "hay",
+ "carbonara",
+ "chocolate syrup",
+ "dough",
+ "meatloaf",
+ "pizza",
+ "pot pie",
+ "burrito",
+ "red wine",
+ "espresso",
+ "tea cup",
+ "eggnog",
+ "mountain",
+ "bubble",
+ "cliff",
+ "coral reef",
+ "geyser",
+ "lakeshore",
+ "promontory",
+ "sandbar",
+ "beach",
+ "valley",
+ "volcano",
+ "baseball player",
+ "bridegroom",
+ "scuba diver",
+ "rapeseed",
+ "daisy",
+ "yellow lady's slipper",
+ "corn",
+ "acorn",
+ "rose hip",
+ "horse chestnut seed",
+ "coral fungus",
+ "agaric",
+ "gyromitra",
+ "stinkhorn mushroom",
+ "earth star fungus",
+ "hen of the woods mushroom",
+ "bolete",
+ "corn cob",
+ "toilet paper",
+]
+
+
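+# Prompt templates from OpenAI's CLIP ImageNet zero-shot evaluation; each entry maps a
+# class name to one caption variant.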
+openai_imagenet_template = [
+ lambda c: f"a bad photo of a {c}.",
+ lambda c: f"a photo of many {c}.",
+ lambda c: f"a sculpture of a {c}.",
+ lambda c: f"a photo of the hard to see {c}.",
+ lambda c: f"a low resolution photo of the {c}.",
+ lambda c: f"a rendering of a {c}.",
+ lambda c: f"graffiti of a {c}.",
+ lambda c: f"a bad photo of the {c}.",
+ lambda c: f"a cropped photo of the {c}.",
+ lambda c: f"a tattoo of a {c}.",
+ lambda c: f"the embroidered {c}.",
+ lambda c: f"a photo of a hard to see {c}.",
+ lambda c: f"a bright photo of a {c}.",
+ lambda c: f"a photo of a clean {c}.",
+ lambda c: f"a photo of a dirty {c}.",
+ lambda c: f"a dark photo of the {c}.",
+ lambda c: f"a drawing of a {c}.",
+ lambda c: f"a photo of my {c}.",
+ lambda c: f"the plastic {c}.",
+ lambda c: f"a photo of the cool {c}.",
+ lambda c: f"a close-up photo of a {c}.",
+ lambda c: f"a black and white photo of the {c}.",
+ lambda c: f"a painting of the {c}.",
+ lambda c: f"a painting of a {c}.",
+ lambda c: f"a pixelated photo of the {c}.",
+ lambda c: f"a sculpture of the {c}.",
+ lambda c: f"a bright photo of the {c}.",
+ lambda c: f"a cropped photo of a {c}.",
+ lambda c: f"a plastic {c}.",
+ lambda c: f"a photo of the dirty {c}.",
+ lambda c: f"a jpeg corrupted photo of a {c}.",
+ lambda c: f"a blurry photo of the {c}.",
+ lambda c: f"a photo of the {c}.",
+ lambda c: f"a good photo of the {c}.",
+ lambda c: f"a rendering of the {c}.",
+ lambda c: f"a {c} in a video game.",
+ lambda c: f"a photo of one {c}.",
+ lambda c: f"a doodle of a {c}.",
+ lambda c: f"a close-up photo of the {c}.",
+ lambda c: f"a photo of a {c}.",
+ lambda c: f"the origami {c}.",
+ lambda c: f"the {c} in a video game.",
+ lambda c: f"a sketch of a {c}.",
+ lambda c: f"a doodle of the {c}.",
+ lambda c: f"a origami {c}.",
+ lambda c: f"a low resolution photo of a {c}.",
+ lambda c: f"the toy {c}.",
+ lambda c: f"a rendition of the {c}.",
+ lambda c: f"a photo of the clean {c}.",
+ lambda c: f"a photo of a large {c}.",
+ lambda c: f"a rendition of a {c}.",
+ lambda c: f"a photo of a nice {c}.",
+ lambda c: f"a photo of a weird {c}.",
+ lambda c: f"a blurry photo of a {c}.",
+ lambda c: f"a cartoon {c}.",
+ lambda c: f"art of a {c}.",
+ lambda c: f"a sketch of the {c}.",
+ lambda c: f"a embroidered {c}.",
+ lambda c: f"a pixelated photo of a {c}.",
+ lambda c: f"itap of the {c}.",
+ lambda c: f"a jpeg corrupted photo of the {c}.",
+ lambda c: f"a good photo of a {c}.",
+ lambda c: f"a plushie {c}.",
+ lambda c: f"a photo of the nice {c}.",
+ lambda c: f"a photo of the small {c}.",
+ lambda c: f"a photo of the weird {c}.",
+ lambda c: f"the cartoon {c}.",
+ lambda c: f"art of the {c}.",
+ lambda c: f"a drawing of the {c}.",
+ lambda c: f"a photo of the large {c}.",
+ lambda c: f"a black and white photo of a {c}.",
+ lambda c: f"the plushie {c}.",
+ lambda c: f"a dark photo of a {c}.",
+ lambda c: f"itap of a {c}.",
+ lambda c: f"graffiti of the {c}.",
+ lambda c: f"a toy {c}.",
+ lambda c: f"itap of my {c}.",
+ lambda c: f"a photo of a cool {c}.",
+ lambda c: f"a photo of a small {c}.",
+ lambda c: f"a tattoo of the {c}.",
+]
diff --git a/audioldm/clap/training/infer_demo.py b/audioldm/clap/training/infer_demo.py
new file mode 100755
index 0000000000000000000000000000000000000000..7d1f4784898dbfeb69affefb6f624711adc8cb42
--- /dev/null
+++ b/audioldm/clap/training/infer_demo.py
@@ -0,0 +1,105 @@
+import sys
+
+import os
+import torch
+import librosa
+from open_clip import create_model
+from training.data import get_audio_features
+from training.data import int16_to_float32, float32_to_int16
+from transformers import RobertaTokenizer
+
+tokenize = RobertaTokenizer.from_pretrained("roberta-base")
+
+
+def tokenizer(text):
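+    # Tokenize with the RoBERTa tokenizer, padding/truncating to 77 tokens and dropping
+    # the singleton batch dimension from each returned tensor (a no-op for batched input).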
+ result = tokenize(
+ text,
+ padding="max_length",
+ truncation=True,
+ max_length=77,
+ return_tensors="pt",
+ )
+ return {k: v.squeeze(0) for k, v in result.items()}
+
+
+PRETRAINED_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/checkpoints/epoch_top_0_audioset_no_fusion.pt"
+WAVE_48k_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/audio/machine.wav"
+
+
+def infer_text():
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ precision = "fp32"
+ amodel = "HTSAT-tiny" # or 'PANN-14'
+ tmodel = "roberta" # the best text encoder in our training
+ enable_fusion = False # False if you do not want to use the fusion model
+ fusion_type = "aff_2d"
+ pretrained = PRETRAINED_PATH
+
+ model, model_cfg = create_model(
+ amodel,
+ tmodel,
+ pretrained,
+ precision=precision,
+ device=device,
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type,
+ )
+    # load the text; this can be a list (i.e., a batch of captions)
+    text_data = ["I love the contrastive learning", "I love the pretrain model"]
+    # tokenize for RoBERTa; to tokenize for another text encoder, refer to data.py#L43-90
+    text_data = tokenizer(text_data)
+
+ text_embed = model.get_text_embedding(text_data)
+ print(text_embed.size())
+
+
+def infer_audio():
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ precision = "fp32"
+ amodel = "HTSAT-tiny" # or 'PANN-14'
+ tmodel = "roberta" # the best text encoder in our training
+ enable_fusion = False # False if you do not want to use the fusion model
+ fusion_type = "aff_2d"
+ pretrained = PRETRAINED_PATH
+
+ model, model_cfg = create_model(
+ amodel,
+ tmodel,
+ pretrained,
+ precision=precision,
+ device=device,
+ enable_fusion=enable_fusion,
+ fusion_type=fusion_type,
+ )
+
+    # load the waveform with shape (T,); it should be resampled to 48 kHz
+    audio_waveform, sr = librosa.load(WAVE_48k_PATH, sr=48000)
+    # simulate 16-bit quantization by round-tripping through int16
+    audio_waveform = int16_to_float32(float32_to_int16(audio_waveform))
+ audio_waveform = torch.from_numpy(audio_waveform).float()
+ audio_dict = {}
+
+    # the 'fusion' truncation mode can be changed to 'rand_trunc' when the fusion model is not used
+    # import ipdb; ipdb.set_trace()  # uncomment to step through interactively
+ audio_dict = get_audio_features(
+ audio_dict,
+ audio_waveform,
+ 480000,
+ data_truncating="fusion",
+ data_filling="repeatpad",
+ audio_cfg=model_cfg["audio_cfg"],
+ )
+    # a list can be passed to the model to process many audio clips at once (i.e., a batch)
+ audio_embed = model.get_audio_embedding([audio_dict])
+ print(audio_embed.size())
+    # import ipdb; ipdb.set_trace()  # uncomment to inspect the embedding interactively
+
+
+if __name__ == "__main__":
+ infer_text()
+ infer_audio()
diff --git a/audioldm/clap/training/logger.py b/audioldm/clap/training/logger.py
new file mode 100755
index 0000000000000000000000000000000000000000..ac4634970fae6aacde2b7b808355dbd50c90ce73
--- /dev/null
+++ b/audioldm/clap/training/logger.py
@@ -0,0 +1,30 @@
+import logging
+
+
+def setup_logging(log_file, level, include_host=False):
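+    # Configure root logging: an optional hostname in the format string, a stream handler,
+    # and an additional file handler when log_file is given.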
+ if include_host:
+ import socket
+
+ hostname = socket.gethostname()
+ formatter = logging.Formatter(
+ f"%(asctime)s | {hostname} | %(levelname)s | %(message)s",
+ datefmt="%Y-%m-%d,%H:%M:%S",
+ )
+ else:
+ formatter = logging.Formatter(
+ "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d,%H:%M:%S"
+ )
+
+ logging.root.setLevel(level)
+ loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
+ for logger in loggers:
+ logger.setLevel(level)
+
+ stream_handler = logging.StreamHandler()
+ stream_handler.setFormatter(formatter)
+ logging.root.addHandler(stream_handler)
+
+ if log_file:
+ file_handler = logging.FileHandler(filename=log_file)
+ file_handler.setFormatter(formatter)
+ logging.root.addHandler(file_handler)
diff --git a/audioldm/clap/training/lp_main.py b/audioldm/clap/training/lp_main.py
new file mode 100755
index 0000000000000000000000000000000000000000..c2d4e8c85aaa3c8e4221963ef56a815cc14f354f
--- /dev/null
+++ b/audioldm/clap/training/lp_main.py
@@ -0,0 +1,670 @@
+from inspect import getargs
+import logging
+import os
+import random
+from datetime import datetime
+import bisect
+import copy
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+from torch import optim
+from torch.cuda.amp import GradScaler
+import faulthandler
+import pathlib
+import argparse
+import time
+
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+try:
+ import torch.utils.tensorboard as tensorboard
+except ImportError:
+ tensorboard = None
+
+try:
+ import horovod.torch as hvd
+except ImportError:
+ hvd = None
+
+from open_clip import create_model_and_transforms, trace_model, create_model
+from training.data import get_data
+from training.params import parse_args
+from training.distributed import is_master, init_distributed_device, world_info_from_env
+from training.logger import setup_logging
+from training.scheduler import cosine_lr
+from training.lp_train import train_one_epoch, evaluate
+from open_clip.utils import get_tar_path_from_dataset_name, dataset_split, get_optimizer
+from open_clip.utils import load_p, load_class_label
+from open_clip.linear_probe import LinearProbe
+
+
+def maintain_ckpts(args, startidx, all_idx_len):
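+    # Shift the saved epoch_top_{i}.pt checkpoints down by one slot, starting from startidx,
+    # then drop the checkpoint that falls off the end of the top-k window.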
+ for i in reversed(range(startidx, all_idx_len)):
+ if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")):
+ os.rename(
+ os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"),
+ os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"),
+ )
+ if os.path.exists(
+ os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")
+ ):
+ os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt"))
+ return
+
+
+def update_top_k_performance(
+ new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True
+):
+ """
+ Record the top-k performance of the current epoch.
+ current_top_k_metrics is a dictionary of the form: {1: top_1_ckpt_measure, 2: top_2_ckpt_measure, ...}
+ """
+ if isinstance(new_metrics_inputs, (list, tuple)):
+ new_metrics_inputs = np.mean(new_metrics_inputs)
+ return update_top_k_performance(
+ new_metrics_inputs,
+ current_top_k_ckpt_metrics,
+ args=args,
+ ckpt=ckpt,
+ bignumbetter=bignumbetter,
+ )
+ elif isinstance(new_metrics_inputs, dict):
+ new_metrics_inputs = np.mean(list(new_metrics_inputs.values()))
+ return update_top_k_performance(
+ new_metrics_inputs,
+ current_top_k_ckpt_metrics,
+ args=args,
+ ckpt=ckpt,
+ bignumbetter=bignumbetter,
+ )
+ elif isinstance(new_metrics_inputs, (float, int)):
+ update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()}
+ sorted_keys = sorted(current_top_k_ckpt_metrics.keys())
+ sorted_values = sorted(
+ current_top_k_ckpt_metrics.values(), reverse=bignumbetter
+ )
+ sorted_values_ = copy.deepcopy(sorted_values)
+ sorted_values.append(new_metrics_inputs)
+ sorted_values = sorted(sorted_values, reverse=bignumbetter)
+ sorted_values = sorted_values[:-1]
+
+ if sorted_values == sorted_values_:
+ return current_top_k_ckpt_metrics, new_metrics_inputs
+ else:
+ for i in range(len(sorted_keys)):
+ if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]:
+ current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i]
+ update_flag[sorted_keys[i]] = True
+ for i in range(len(update_flag)):
+ if update_flag[i]:
+ maintain_ckpts(args, i, len(sorted_keys))
+ torch.save(
+ ckpt,
+ os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"),
+ )
+ break
+ return current_top_k_ckpt_metrics, new_metrics_inputs
+
+
+# def updateifNone(a, b):
+# a = b if None else a
+# return a
+
+
+def is_pretrained_params(n):
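+    # Parameters of the pretrained CLAP text tower (transformer, embeddings, final
+    # LayerNorm, text logit scale) are treated as the "pretrained" group.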
+ return (
+ n.startswith("clap_model.transformer")
+ or n in ["clap_model.positional_embedding", "clap_model.text_projection"]
+ or n.startswith("clap_model.token_embedding")
+ or n.startswith("clap_model.ln_final")
+ or n.startswith("clap_model.logit_scale_t")
+ )
+
+
+def random_seed(seed=42, rank=0):
+ torch.manual_seed(seed + rank)
+ np.random.seed(seed + rank)
+ random.seed(seed + rank)
+
+
+def config_lp_optimizer(model, data, args):
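+    # Build the optimizer(s) and cosine scheduler(s): either for the full CLAP backbone
+    # (optionally split into pretrained/new parameter groups) or, when args.lp_freeze is
+    # set, only for the linear-probe head.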
+ # set wd-related params to 0 if use adam optimizer
+ if args.optimizer == "adam":
+ args.wd = 0
+ args.wd_pretrained = 0
+ args.wd_new = 0
+
+ in_clap = lambda n, p: n.startswith("clap_model")
+
+ named_parameters = list(model.named_parameters())
+
+ optimizer = {}
+ scheduler = {}
+
+ # freeze text encoder
+ text_freeze_parameters = [
+ p
+ for n, p in named_parameters
+ if n.startswith("clap_model.transformer")
+ or n in ["clap_model.positional_embedding", "clap_model.text_projection"]
+ or n.startswith("clap_model.token_embedding")
+ or n.startswith("clap_model.ln_final")
+ ]
+
+ if args.freeze_text:
+ logging.info("Freeze Text!!!!")
+ for k in text_freeze_parameters:
+ k.requires_grad = False
+
+ if not args.lp_freeze:
+ exclude = (
+ lambda n, p: p.ndim < 2
+ or "bn" in n
+ or "ln" in n
+ or "bias" in n
+ or "logit_scale" in n
+ )
+ include = lambda n, p: not exclude(n, p)
+
+ # (yusong): we do not split the learning rate anymore
+ # p for n, p in named_parameters if in_clap(n,p) and exclude(n, p) and p.requires_grad
+ gain_or_bias_params = [
+ p for n, p in named_parameters if exclude(n, p) and p.requires_grad
+ ]
+ # rest_params = [p for n, p in named_parameters if in_clap(n,p) and include(n, p) and p.requires_grad]
+ rest_params = [
+ p for n, p in named_parameters if include(n, p) and p.requires_grad
+ ]
+
+ if args.train_data is None:
+ optimizer = None
+ scheduler = None
+ else:
+ total_steps = data["train"].dataloader.num_batches * args.epochs
+
+ if args.split_opt:
+ for x in ["lr", "beta1", "beta2", "eps", "wd"]:
+ for y in ["_new", "_pretrained"]:
+ if getattr(args, x + y) is None:
+ setattr(args, x + y, getattr(args, x))
+
+ gain_or_bias_pretrained_params = [
+ p
+ for n, p in named_parameters
+ if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n)
+ ]
+ rest_pretrained_params = [
+ p
+ for n, p in named_parameters
+ if (include(n, p) and p.requires_grad) and is_pretrained_params(n)
+ ]
+ gain_or_bias_new_params = [
+ p
+ for n, p in named_parameters
+ if (exclude(n, p) and p.requires_grad)
+ and (not is_pretrained_params(n))
+ ]
+ rest_new_params = [
+ p
+ for n, p in named_parameters
+ if (include(n, p) and p.requires_grad)
+ and (not is_pretrained_params(n))
+ ]
+
+ pretrained_params_optimizer = get_optimizer(
+ [
+ {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0},
+ {
+ "params": rest_pretrained_params,
+ "weight_decay": args.wd_pretrained,
+ },
+ ],
+ lr=args.lr_pretrained,
+ betas=(args.beta1_pretrained, args.beta2_pretrained),
+ eps=args.eps_pretrained,
+ momentum=args.momentum_pretrained,
+ optimizer_name=args.optimizer,
+ )
+ pretrained_params_scheduler = cosine_lr(
+ pretrained_params_optimizer,
+ args.lr_pretrained,
+ args.warmup,
+ total_steps,
+ )
+
+ new_params_optimizer = get_optimizer(
+ [
+ {"params": gain_or_bias_new_params, "weight_decay": 0.0},
+ {"params": rest_new_params, "weight_decay": args.wd_new},
+ ],
+ lr=args.lr_new,
+ betas=(args.beta1_new, args.beta2_new),
+ eps=args.eps_new,
+ momentum=args.momentum_new,
+ optimizer_name=args.optimizer,
+ )
+ new_params_scheduler = cosine_lr(
+ new_params_optimizer, args.lr_new, args.warmup, total_steps
+ )
+
+ optimizer["text"] = pretrained_params_optimizer
+ optimizer["audio"] = new_params_optimizer
+ scheduler["text"] = pretrained_params_scheduler
+ scheduler["audio"] = new_params_scheduler
+
+ if args.horovod:
+ pretrained_params_optimizer = hvd.DistributedOptimizer(
+ pretrained_params_optimizer,
+ named_parameters=model.named_parameters(),
+ )
+ new_params_optimizer = hvd.DistributedOptimizer(
+ new_params_optimizer, named_parameters=model.named_parameters()
+ )
+ hvd.broadcast_parameters(model.state_dict(), root_rank=0)
+ hvd.broadcast_optimizer_state(
+ pretrained_params_optimizer, root_rank=0
+ )
+ hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0)
+ else:
+
+ optimizer["clap"] = get_optimizer(
+ [
+ {"params": gain_or_bias_params, "weight_decay": 0.0},
+ {"params": rest_params, "weight_decay": args.wd},
+ ],
+ lr=args.lr,
+ betas=(args.beta1, args.beta2),
+ eps=args.eps,
+ momentum=args.momentum,
+ optimizer_name=args.optimizer,
+ )
+ scheduler["clap"] = cosine_lr(
+ optimizer["clap"], args.lr, args.warmup, total_steps
+ )
+
+ if args.horovod:
+ optimizer["clap"] = hvd.DistributedOptimizer(
+ optimizer["clap"], named_parameters=model.named_parameters()
+ )
+ hvd.broadcast_parameters(model.state_dict(), root_rank=0)
+ hvd.broadcast_optimizer_state(optimizer["clap"], root_rank=0)
+
+ # linear probe optimizer
+ else:
+ lp_params = [
+ p for n, p in named_parameters if (not in_clap(n, p)) and p.requires_grad
+ ]
+ lp_optim = get_optimizer(
+ lp_params,
+ lr=args.lp_lr,
+ betas=(args.beta1, args.beta2),
+ eps=args.eps,
+ momentum=0.9,
+ optimizer_name=args.optimizer,
+ )
+ optimizer["lp"] = lp_optim
+
+ return optimizer, scheduler, text_freeze_parameters
+
+
+def main():
+ args = parse_args()
+
+ time.sleep(args.sleep)
+
+    # sanitize the model name for filesystem / URI use; simpler if we avoid "/" in names as a rule
+ args.amodel = args.amodel.replace("/", "-")
+ # download sizes.json file
+
+ # (yusong): the below two lines are for debug
+ # print("setting up faulthandler")
+ # faulthandler.register(10)
+
+ random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+ np.random.seed(args.seed)
+ args.class_index_dict = load_class_label(args.class_label_path)
+
+ # get the name of the experiments
+ if args.name is None:
+ args.name = "-".join(
+ [
+ datetime.now().strftime("%Y_%m_%d-%H_%M_%S"),
+                "linear_probe",
+                f"model_{args.amodel}",
+ f"lr_{args.lr}",
+ f"b_{args.batch_size}",
+ f"j_{args.workers}",
+ f"p_{args.precision}",
+ ]
+ )
+
+ # discover initial world args early so we can log properly
+ args.distributed = False
+ args.local_rank, args.rank, args.world_size = world_info_from_env()
+
+ if args.remotedata and is_master(args):
+ for dataset_name in args.datasetnames:
+ for split in dataset_split[dataset_name]:
+ if not os.path.exists(f"./json_files/{dataset_name}/{split}"):
+ os.makedirs(f"./json_files/{dataset_name}/{split}")
+ os.system(
+ f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json"
+ )
+
+ args.log_path = None
+ if is_master(args, local=args.log_local):
+ log_base_path = os.path.join(args.logs, args.name)
+ os.makedirs(log_base_path, exist_ok=True)
+ log_filename = f"out-{args.rank}" if args.log_local else "out.log"
+ args.log_path = os.path.join(log_base_path, log_filename)
+
+        # avoid reusing a log directory with the same name:
+ postfix = 0
+ while os.path.exists(args.log_path):
+ postfix += 1
+ log_base_path_new = log_base_path + "-" + str(postfix)
+ os.makedirs(log_base_path_new, exist_ok=True)
+ log_filename = f"out-{args.rank}" if args.log_local else "out.log"
+ args.log_path = os.path.join(log_base_path_new, log_filename)
+ # print(
+ # "Error. Experiment already exists. Use --name {} to specify a new experiment."
+ # )
+ # return -1
+
+ # Set logger
+ args.log_level = logging.DEBUG if args.debug else logging.INFO
+ setup_logging(args.log_path, args.log_level)
+
+ # fully initialize distributed device environment
+ device = init_distributed_device(args)
+
+ args.wandb = "wandb" in args.report_to or "all" in args.report_to
+ args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to
+ if is_master(args):
+ args.tensorboard_path = (
+ os.path.join(args.logs, args.name, "tensorboard")
+ if args.tensorboard
+ else ""
+ )
+ args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
+ for dirname in [args.tensorboard_path, args.checkpoint_path]:
+ if dirname:
+ os.makedirs(dirname, exist_ok=True)
+ else:
+ args.tensorboard_path = ""
+ args.checkpoint_path = ""
+
+ if args.copy_codebase:
+ copy_codebase(args)
+
+ assert args.precision in ["amp", "fp16", "fp32"]
+ if args.precision == "fp16":
+ logging.warning(
+ "It is recommended to use AMP mixed-precision instead of FP16. "
+ "FP16 support needs further verification and tuning, especially for train."
+ )
+
+ if args.horovod:
+ logging.info(
+ f"Running in horovod mode with multiple processes / nodes. Device: {args.device}."
+ f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}."
+ )
+ elif args.distributed:
+ logging.info(
+ f"Running in distributed mode with multiple processes. Device: {args.device}."
+ f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}."
+ )
+ else:
+ logging.info(f"Running with a single process. Device {args.device}.")
+
+ logging.info(f"openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}")
+
+ # Create CLAP model
+ clap_model, clap_model_cfg = create_model(
+ args.amodel,
+ args.tmodel,
+ args.pretrained,
+ precision=args.precision,
+ device=device,
+ jit=args.torchscript,
+ force_quick_gelu=args.force_quick_gelu,
+ openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
+ skip_params=False,
+ pretrained_audio=args.pretrained_audio,
+ pretrained_text=args.pretrained_text,
+ enable_fusion=args.enable_fusion,
+ fusion_type=args.fusion_type,
+ )
+
+ args.lp_out_ch = len(list(args.class_index_dict.keys()))
+ # Linear Probe
+ logging.info(f"linear probe using mlp: {args.lp_mlp}")
+ logging.info(f"linear probe using freeze: {args.lp_freeze}")
+ logging.info(f"linear probe act layer: {args.lp_act}")
+ logging.info(f"linear probe out ch: {args.lp_out_ch}")
+ logging.info(f"linear probe learning rate (if applicable): {args.lp_lr}")
+ logging.info(f"linear probe loss func: {args.lp_loss}")
+ logging.info(f"linear probe lp_metrics: {args.lp_metrics}")
+
+ model = LinearProbe(
+ clap_model,
+ mlp=args.lp_mlp,
+ freeze=args.lp_freeze,
+ in_ch=512,
+ out_ch=args.lp_out_ch,
+ act=args.lp_act,
+ ) # in_ch is fixed (i.e., 512)
+ model = model.to(device)
+
+ if args.horovod:
+ with torch.no_grad():
+ for param in model.parameters():
+ param.set_(param.contiguous())
+
+ if args.trace:
+ model = trace_model(model, batch_size=args.batch_size, device=device)
+
+ if is_master(args):
+ logging.info("Linear Probe CLAP Model:")
+ logging.info(f"{str(clap_model)}")
+ logging.info("Params:")
+ params_file = os.path.join(args.logs, args.name, "params.txt")
+ with open(params_file, "w") as f:
+ for name in sorted(vars(args)):
+ val = getattr(args, name)
+ logging.info(f" {name}: {val}")
+ f.write(f"{name}: {val}\n")
+
+ if args.distributed and not args.horovod:
+ if args.use_bn_sync:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ ddp_args = {}
+ if args.ddp_static_graph:
+ # this doesn't exist in older PyTorch, arg only added if enabled
+ ddp_args["static_graph"] = True
+ model = torch.nn.parallel.DistributedDataParallel(
+ model, device_ids=[device], find_unused_parameters=True, **ddp_args
+ )
+
+ data = get_data(args, clap_model_cfg)
+ assert len(data), "At least one train or eval dataset must be specified."
+ if args.trace:
+ assert "train" not in data, "Cannot train with traced model"
+
+ optimizer, scheduler, text_freeze_parameters = config_lp_optimizer(
+ model, data, args
+ )
+
+ scaler = GradScaler() if args.precision == "amp" else None
+
+ # optionally resume from a checkpoint
+ start_epoch = 0
+ if args.resume is not None:
+ if os.path.isfile(args.resume):
+ checkpoint = torch.load(args.resume, map_location=device)
+ if "epoch" in checkpoint:
+ # resuming a train checkpoint w/ epoch and optimizer state
+ start_epoch = checkpoint["epoch"]
+ sd = checkpoint["state_dict"]
+ if not args.distributed and next(iter(sd.items()))[0].startswith(
+ "module"
+ ):
+ sd = {k[len("module.") :]: v for k, v in sd.items()}
+ model.load_state_dict(sd)
+ if args.split_opt:
+ if optimizer is not None:
+ for k, o_ in optimizer.items():
+ o_.load_state_dict(checkpoint[k + "_" + "optimizer"])
+ if optimizer is not None:
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ if scaler is not None and "scaler" in checkpoint:
+ scaler.load_state_dict(checkpoint["scaler"])
+ logging.info(
+ f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})"
+ )
+ else:
+ # loading a bare (model only) checkpoint for fine-tune or evaluation
+ model.load_state_dict(checkpoint)
+ logging.info(
+ f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})"
+ )
+ if args.freeze_text:
+ print("Freeze Text!!!!")
+ for k in text_freeze_parameters:
+ k.requires_grad = False
+ else:
+ logging.info("=> no checkpoint found at '{}'".format(args.resume))
+
+ cudnn.benchmark = True
+ cudnn.deterministic = False
+
+ # determine if this worker should save logs and checkpoints. only do so if it is rank == 0
+ args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args)
+ writer = None
+ if args.save_logs and args.tensorboard:
+ assert tensorboard is not None, "Please install tensorboard."
+ writer = tensorboard.SummaryWriter(args.tensorboard_path)
+
+ if args.wandb and is_master(args):
+ assert wandb is not None, "Please install wandb."
+ logging.debug("Starting wandb.")
+ args.train_sz = data["train"].dataloader.num_samples
+ if args.val_data is not None:
+ args.val_sz = data["val"].dataloader.num_samples
+ # you will have to configure this for your project!
+ wandb.init(
+ project="clap",
+ notes=args.wandb_notes,
+ name=args.wandb_notes,
+ tags=[],
+ config=vars(args),
+ )
+ if args.debug:
+ wandb.watch(model, log="all")
+ wandb.save(params_file)
+ logging.debug("Finished loading wandb.")
+
+ if "train" not in data:
+ evaluate(model, data, start_epoch, args, writer)
+ return
+ elif start_epoch == 0 and "val" in data and not args.no_eval:
+ evaluate(model, data, 0, args, writer)
+ if args.save_top_performance:
+ current_top_k_ckpt_metrics = {
+ i: 0 for i in range(args.save_top_performance)
+ } # initialize the top-k metric for ckpts to 0
+
+ for epoch in range(start_epoch, args.epochs):
+        # freeze the text parameters from epoch args.freeze_text_after onward (inclusive); this is -1 by default
+        if epoch == args.freeze_text_after:
+            print("Pretrained text parameters are frozen from this epoch onward.")
+ for k in text_freeze_parameters:
+ k.requires_grad = False
+ if is_master(args):
+ logging.info(f"Start epoch {epoch}")
+
+ train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer)
+ completed_epoch = epoch + 1
+
+ if (
+ any(v in data for v in ("val", "imagenet-val", "imagenet-v2"))
+ and not args.no_eval
+ ):
+ metrics = evaluate(model, data, completed_epoch, args, writer)
+ if args.save_top_performance:
+ top_k_dataset = args.top_k_checkpoint_select_dataset
+ top_k_metric = args.top_k_checkpoint_select_metric
+ filtered_metrics = [
+ v
+ for k, v in metrics.items()
+ if top_k_metric in k and top_k_dataset in k
+                ]  # check all R@10 metrics (across all datasets) and use them to update the checkpoint
+ # Saving checkpoints.
+ if args.save_logs:
+ opt_dict = {
+ k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items()
+ }
+ checkpoint_dict = {
+ "epoch": completed_epoch,
+ "name": args.name,
+ "state_dict": model.state_dict(),
+ }
+ checkpoint_dict.update(opt_dict)
+ if scaler is not None:
+ checkpoint_dict["scaler"] = scaler.state_dict()
+
+ if completed_epoch == args.epochs or (
+ args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
+ ):
+ torch.save(
+ checkpoint_dict,
+ os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"),
+ )
+ if args.save_most_recent:
+ torch.save(
+ checkpoint_dict,
+ os.path.join(args.checkpoint_path, f"epoch_latest.pt"),
+ )
+ if args.save_top_performance and not args.no_eval:
+ update_top_k_performance(
+ filtered_metrics,
+ current_top_k_ckpt_metrics,
+ args,
+ checkpoint_dict,
+ bignumbetter=True,
+ )
+
+ if args.wandb and is_master(args):
+ wandb.finish()
+
+
+def copy_codebase(args):
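+    # Snapshot the source tree (three directories above this file) into the experiment's
+    # log folder, skipping log/logs/wandb, so the run can be reproduced later.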
+ from shutil import copytree, ignore_patterns
+
+ new_code_path = os.path.join(args.logs, args.name, "code")
+ if os.path.exists(new_code_path):
+ print(
+ f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
+ )
+ return -1
+ print(f"Copying codebase to {new_code_path}")
+ current_code_path = os.path.realpath(__file__)
+ for _ in range(3):
+ current_code_path = os.path.dirname(current_code_path)
+ copytree(
+ current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb")
+ )
+ print("Done copying code.")
+ return 1
+
+
+if __name__ == "__main__":
+ main()
diff --git a/audioldm/clap/training/lp_train.py b/audioldm/clap/training/lp_train.py
new file mode 100755
index 0000000000000000000000000000000000000000..24a19bacd0a4b789415cfccbce1f8bc99bc493ed
--- /dev/null
+++ b/audioldm/clap/training/lp_train.py
@@ -0,0 +1,301 @@
+import json
+import logging
+import math
+import os
+import time
+from contextlib import suppress
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+from open_clip import LPLoss, LPMetrics, lp_gather_features
+from open_clip.utils import do_mixup, get_mix_lambda
+from .distributed import is_master
+from .zero_shot import zero_shot_eval
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def unwrap_model(model):
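+    # Return the underlying module when the model is wrapped (e.g., by DistributedDataParallel).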
+ if hasattr(model, "module"):
+ return model.module
+ else:
+ return model
+
+
+def train_one_epoch(
+ model,
+ data,
+ epoch,
+ optimizer,
+ scaler,
+ scheduler,
+ args,
+ tb_writer=None,
+ extra_suffix="",
+):
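+    # One linear-probe training epoch: optional mixup of the class labels, AMP autocast,
+    # stepping either a single optimizer or a dict of per-group optimizers, and periodic
+    # loss/LR logging to TensorBoard and wandb.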
+ device = torch.device(args.device)
+ autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
+ model.train()
+ loss = LPLoss(args.lp_loss)
+
+ dataloader, sampler = data["train"].dataloader, data["train"].sampler
+ if args.distributed and sampler is not None:
+ sampler.set_epoch(epoch)
+ num_batches_per_epoch = dataloader.num_batches
+ sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))
+
+ # for toy dataset
+ if args.dataset_type == "toy":
+ dataloader.dataset.generate_queue()
+
+ loss_m = AverageMeter()
+ batch_time_m = AverageMeter()
+ data_time_m = AverageMeter()
+ end = time.time()
+
+ for i, batch in enumerate(dataloader):
+ step = num_batches_per_epoch * epoch + i
+
+ if isinstance(scheduler, dict):
+ for s in scheduler.values():
+ s(step)
+ else:
+ scheduler(step)
+
+        audio = batch  # contains mel_spec, waveform, and the "longer" flag list
+ class_label = batch["class_label"]
+ # audio = audio.to(device=device, non_blocking=True)
+ class_label = class_label.to(device=device, non_blocking=True)
+
+ if args.mixup:
+ # https://github.com/RetroCirce/HTS-Audio-Transformer/blob/main/utils.py#L146
+ mix_lambda = torch.from_numpy(
+ get_mix_lambda(0.5, len(audio["waveform"]))
+ ).to(device)
+ class_label = do_mixup(class_label, mix_lambda)
+ else:
+ mix_lambda = None
+
+ data_time_m.update(time.time() - end)
+ if isinstance(optimizer, dict):
+ for o_ in optimizer.values():
+ o_.zero_grad()
+ else:
+ optimizer.zero_grad()
+
+ with autocast():
+ pred = model(audio, mix_lambda=mix_lambda, device=device)
+ total_loss = loss(pred, class_label)
+
+ if isinstance(optimizer, dict):
+ if scaler is not None:
+ scaler.scale(total_loss).backward()
+ for o_ in optimizer.values():
+ if args.horovod:
+ o_.synchronize()
+ scaler.unscale_(o_)
+ with o_.skip_synchronize():
+ scaler.step(o_)
+ else:
+ scaler.step(o_)
+ scaler.update()
+ else:
+ total_loss.backward()
+ for o_ in optimizer.values():
+ o_.step()
+ else:
+ if scaler is not None:
+ scaler.scale(total_loss).backward()
+ if args.horovod:
+ optimizer.synchronize()
+ scaler.unscale_(optimizer)
+ with optimizer.skip_synchronize():
+ scaler.step(optimizer)
+ else:
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ total_loss.backward()
+ optimizer.step()
+
+ # Note: we clamp to 4.6052 = ln(100), as in the original paper.
+ with torch.no_grad():
+ unwrap_model(model).clap_model.logit_scale_a.clamp_(0, math.log(100))
+ unwrap_model(model).clap_model.logit_scale_t.clamp_(0, math.log(100))
+
+ batch_time_m.update(time.time() - end)
+ end = time.time()
+ batch_count = i + 1
+
+ if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
+ if isinstance(audio, dict):
+ batch_size = len(audio["waveform"])
+ else:
+ batch_size = len(audio)
+ num_samples = batch_count * batch_size * args.world_size
+ samples_per_epoch = dataloader.num_samples
+ percent_complete = 100.0 * batch_count / num_batches_per_epoch
+
+            # NOTE: the loss meter is only updated on the master node and once per log interval, so it is coarsely sampled
+ loss_m.update(total_loss.item(), batch_size)
+ if isinstance(optimizer, dict):
+ logging.info(
+ f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
+ f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
+ f"Data (t): {data_time_m.avg:.3f} "
+ f"Batch (t): {batch_time_m.avg:.3f} "
+ f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]}"
+ )
+ log_data = {
+ "loss": loss_m.val,
+ "data_time": data_time_m.val,
+ "batch_time": batch_time_m.val,
+ "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()],
+ }
+ else:
+ logging.info(
+ f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
+ f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
+ f"Data (t): {data_time_m.avg:.3f} "
+ f"Batch (t): {batch_time_m.avg:.3f} "
+ f"LR: {optimizer.param_groups[0]['lr']:5f} "
+ )
+
+ # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
+ log_data = {
+ "loss": loss_m.val,
+ "data_time": data_time_m.val,
+ "batch_time": batch_time_m.val,
+ "lr": optimizer.param_groups[0]["lr"],
+ }
+ for name, val in log_data.items():
+ name = f"train{extra_suffix}/{name}"
+ if tb_writer is not None:
+ tb_writer.add_scalar(name, val, step)
+ if args.wandb:
+ assert wandb is not None, "Please install wandb."
+ wandb.log({name: val, "step": step})
+
+ # resetting batch / data time meters per log window
+ batch_time_m.reset()
+ data_time_m.reset()
+ # end for
+
+
+def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""):
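+    # Validate the linear probe: non-master ranks skip evaluation unless args.parallel_eval
+    # is set; predictions and targets are collected (gathered across ranks in parallel mode),
+    # the metrics named in args.lp_metrics are computed, and the results are logged to the
+    # console, TensorBoard, results.jsonl, and wandb.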
+ metrics = {}
+ if not args.parallel_eval:
+ if not is_master(args):
+ return metrics
+ device = torch.device(args.device)
+ model.eval()
+
+ # CHANGE
+ # zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
+ # metrics.update(zero_shot_metrics)
+ if is_master(args):
+ print("Evaluating...")
+ metric_names = args.lp_metrics.split(",")
+ eval_tool = LPMetrics(metric_names=metric_names)
+
+ autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
+ if "val" in data and (
+ args.val_frequency
+ and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)
+ ):
+        if args.parallel_eval:
+            dataloader, sampler = data["val"].dataloader, data["val"].sampler
+            if args.distributed and sampler is not None:
+                sampler.set_epoch(epoch)
+        else:
+            dataloader = data["val"].dataloader
+        # track the running sample count in both eval modes
+        num_samples = 0
+        samples_per_val = dataloader.num_samples
+
+ eval_info = {"pred": [], "target": []}
+ with torch.no_grad():
+ for i, batch in enumerate(dataloader):
+                audio = batch  # contains mel_spec, waveform, and longer list
+ class_label = batch["class_label"]
+
+ # audio = audio.to(device=device, non_blocking=True)
+ class_label = class_label.to(device=device, non_blocking=True)
+
+ with autocast():
+ pred = model(audio, device=device)
+ if args.parallel_eval:
+ pred, class_label = lp_gather_features(
+ pred, class_label, args.world_size, args.horovod
+ )
+ eval_info["pred"].append(pred)
+ eval_info["target"].append(class_label)
+
+ num_samples += class_label.shape[0]
+
+ if (i % 100) == 0: # and i != 0:
+ logging.info(
+ f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]"
+ )
+
+ if is_master(args):
+ eval_info["pred"] = torch.cat(eval_info["pred"], 0).cpu()
+ eval_info["target"] = torch.cat(eval_info["target"], 0).cpu()
+ metric_dict = eval_tool.evaluate_mertics(
+ eval_info["pred"], eval_info["target"]
+ )
+ metrics.update(metric_dict)
+ if "epoch" not in metrics.keys():
+ metrics.update({"epoch": epoch})
+
+ if is_master(args):
+ if not metrics:
+ return metrics
+
+ logging.info(
+ f"Eval Epoch: {epoch} "
+ + "\n".join(
+                [f"{m}: {round(metrics[m], 4):.4f}" for m in metrics]
+ )
+ )
+ if args.save_logs:
+ for name, val in metrics.items():
+ if tb_writer is not None:
+ tb_writer.add_scalar(f"val{extra_suffix}/{name}", val, epoch)
+
+ with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
+ f.write(json.dumps(metrics))
+ f.write("\n")
+
+ if args.wandb:
+ assert wandb is not None, "Please install wandb."
+ for name, val in metrics.items():
+ wandb.log({f"val{extra_suffix}/{name}": val, "epoch": epoch})
+
+ return metrics
+ else:
+ return metrics
diff --git a/audioldm/clap/training/main.py b/audioldm/clap/training/main.py
new file mode 100755
index 0000000000000000000000000000000000000000..3b563a5d001be7adfbe779dee7ad8ac49aadc50d
--- /dev/null
+++ b/audioldm/clap/training/main.py
@@ -0,0 +1,596 @@
+from inspect import getargs
+import logging
+import os
+import random
+from datetime import datetime
+import bisect
+import copy
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+from torch import optim
+from torch.cuda.amp import GradScaler
+import faulthandler
+import pathlib
+
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+try:
+ import torch.utils.tensorboard as tensorboard
+except ImportError:
+ tensorboard = None
+
+try:
+ import horovod.torch as hvd
+except ImportError:
+ hvd = None
+
+from open_clip import create_model_and_transforms, trace_model, create_model
+from training.data import get_data
+from training.distributed import is_master, init_distributed_device, world_info_from_env
+from training.logger import setup_logging
+from training.params import parse_args
+from training.scheduler import cosine_lr
+from training.train import train_one_epoch, evaluate
+from open_clip.utils import dataset_split, get_optimizer
+
+
+def maintain_ckpts(args, startidx, all_idx_len):
+ for i in reversed(range(startidx, all_idx_len)):
+ if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")):
+ os.rename(
+ os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"),
+ os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"),
+ )
+ if os.path.exists(
+ os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")
+ ):
+ os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt"))
+ return
+
+
+def update_top_k_performance(
+ new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True
+):
+ """
+ Record the top-k performance of the current epoch.
+    current_top_k_ckpt_metrics is a dictionary of the form: {0: top_1_ckpt_measure, 1: top_2_ckpt_measure, ...}
+ """
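+    # Illustrative walk-through (assuming k=3): with current metrics
+    # {0: 0.82, 1: 0.80, 2: 0.75} and a new value of 0.81, slots 1 and below shift
+    # down (maintain_ckpts renames epoch_top_1.pt -> epoch_top_2.pt and the old
+    # epoch_top_2.pt is dropped), the new checkpoint is saved as epoch_top_1.pt,
+    # and the metric dict becomes {0: 0.82, 1: 0.81, 2: 0.80}.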
+ if isinstance(new_metrics_inputs, (list, tuple)):
+ new_metrics_inputs = np.mean(new_metrics_inputs)
+ return update_top_k_performance(
+ new_metrics_inputs,
+ current_top_k_ckpt_metrics,
+ args=args,
+ ckpt=ckpt,
+ bignumbetter=bignumbetter,
+ )
+ elif isinstance(new_metrics_inputs, dict):
+ new_metrics_inputs = np.mean(list(new_metrics_inputs.values()))
+ return update_top_k_performance(
+ new_metrics_inputs,
+ current_top_k_ckpt_metrics,
+ args=args,
+ ckpt=ckpt,
+ bignumbetter=bignumbetter,
+ )
+ elif isinstance(new_metrics_inputs, (float, int)):
+ update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()}
+ sorted_keys = sorted(current_top_k_ckpt_metrics.keys())
+ sorted_values = sorted(
+ current_top_k_ckpt_metrics.values(), reverse=bignumbetter
+ )
+ sorted_values_ = copy.deepcopy(sorted_values)
+ sorted_values.append(new_metrics_inputs)
+ sorted_values = sorted(sorted_values, reverse=bignumbetter)
+ sorted_values = sorted_values[:-1]
+
+ if sorted_values == sorted_values_:
+ return current_top_k_ckpt_metrics, new_metrics_inputs
+ else:
+ for i in range(len(sorted_keys)):
+ if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]:
+ current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i]
+ update_flag[sorted_keys[i]] = True
+ for i in range(len(update_flag)):
+ if update_flag[i]:
+ maintain_ckpts(args, i, len(sorted_keys))
+ torch.save(
+ ckpt,
+ os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"),
+ )
+ break
+ return current_top_k_ckpt_metrics, new_metrics_inputs
+
+
+# def updateifNone(a, b):
+# a = b if None else a
+# return a
+
+
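+# Heuristic used by --split-opt: the parameter names below belong to the pretrained
+# text tower (CLIP-style transformer, token/positional embeddings, final LN, text
+# projection, text logit scale); everything else is treated as "new" (audio branch,
+# etc.) and gets its own optimizer and LR schedule in main().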
+def is_pretrained_params(n):
+ return (
+ n.startswith("transformer")
+ or n in ["positional_embedding", "text_projection"]
+ or n.startswith("token_embedding")
+ or n.startswith("ln_final")
+ or n.startswith("logit_scale_t")
+ )
+
+
+def random_seed(seed=42, rank=0):
+ torch.manual_seed(seed + rank)
+ np.random.seed(seed + rank)
+ random.seed(seed + rank)
+
+
+def main():
+ args = parse_args()
+ # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?
+ args.amodel = args.amodel.replace("/", "-")
+ # download sizes.json file
+
+ # (yusong): the below two lines are for debug
+ # print("setting up faulthandler")
+ # faulthandler.register(10)
+
+ random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+ np.random.seed(args.seed)
+ if args.tmodel == "bert" or args.tmodel == "roberta" or args.tmodel == "bart":
+ assert (
+ args.pretrained == "" or args.pretrained is None
+ ), "bert/roberta/bart text encoder does not support pretrained models."
+
+ # get the name of the experiments
+ if args.name is None:
+ args.name = "-".join(
+ [
+ datetime.now().strftime("%Y_%m_%d-%H_%M_%S"),
+ f"model_{args.amodel}",
+ f"lr_{args.lr}",
+ f"b_{args.batch_size}",
+ f"j_{args.workers}",
+ f"p_{args.precision}",
+ ]
+ )
+
+ # discover initial world args early so we can log properly
+ args.distributed = False
+ args.local_rank, args.rank, args.world_size = world_info_from_env()
+
+ if args.remotedata and is_master(args):
+ for dataset_name in args.datasetnames:
+ for split in dataset_split[dataset_name]:
+ if not os.path.exists(f"./json_files/{dataset_name}/{split}"):
+ os.makedirs(f"./json_files/{dataset_name}/{split}")
+ os.system(
+ f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json"
+ )
+
+ args.log_path = None
+ if is_master(args, local=args.log_local):
+ log_base_path = os.path.join(args.logs, args.name)
+ os.makedirs(log_base_path, exist_ok=True)
+ log_filename = f"out-{args.rank}" if args.log_local else "out.log"
+ args.log_path = os.path.join(log_base_path, log_filename)
+ if os.path.exists(args.log_path):
+ print(
+ "Error. Experiment already exists. Use --name {} to specify a new experiment."
+ )
+ return -1
+
+ # Set logger
+ args.log_level = logging.DEBUG if args.debug else logging.INFO
+ setup_logging(args.log_path, args.log_level)
+
+ # fully initialize distributed device environment
+ device = init_distributed_device(args)
+
+ args.wandb = "wandb" in args.report_to or "all" in args.report_to
+ args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to
+ if is_master(args):
+ args.tensorboard_path = (
+ os.path.join(args.logs, args.name, "tensorboard")
+ if args.tensorboard
+ else ""
+ )
+ args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
+ for dirname in [args.tensorboard_path, args.checkpoint_path]:
+ if dirname:
+ os.makedirs(dirname, exist_ok=True)
+ else:
+ args.tensorboard_path = ""
+ args.checkpoint_path = ""
+
+ if args.copy_codebase:
+ copy_codebase(args)
+
+ assert args.precision in ["amp", "fp16", "fp32"]
+ if args.precision == "fp16":
+ logging.warning(
+ "It is recommended to use AMP mixed-precision instead of FP16. "
+ "FP16 support needs further verification and tuning, especially for train."
+ )
+
+ if args.horovod:
+ logging.info(
+ f"Running in horovod mode with multiple processes / nodes. Device: {args.device}."
+ f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}."
+ )
+ elif args.distributed:
+ logging.info(
+ f"Running in distributed mode with multiple processes. Device: {args.device}."
+ f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}."
+ )
+ else:
+ logging.info(f"Running with a single process. Device {args.device}.")
+
+ logging.info(f"openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}")
+
+ model, model_cfg = create_model(
+ args.amodel,
+ args.tmodel,
+ args.pretrained,
+ precision=args.precision,
+ device=device,
+ jit=args.torchscript,
+ force_quick_gelu=args.force_quick_gelu,
+ openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir),
+ skip_params=True,
+ pretrained_audio=args.pretrained_audio,
+ pretrained_text=args.pretrained_text,
+ enable_fusion=args.enable_fusion,
+ fusion_type=args.fusion_type,
+ )
+
+ if args.horovod:
+ with torch.no_grad():
+ for param in model.parameters():
+ param.set_(param.contiguous())
+
+ if args.trace:
+ model = trace_model(model, batch_size=args.batch_size, device=device)
+
+ if is_master(args):
+ logging.info("Model:")
+ logging.info(f"{str(model)}")
+ logging.info("Params:")
+ params_file = os.path.join(args.logs, args.name, "params.txt")
+ with open(params_file, "w") as f:
+ for name in sorted(vars(args)):
+ val = getattr(args, name)
+ logging.info(f" {name}: {val}")
+ f.write(f"{name}: {val}\n")
+
+ if args.distributed and not args.horovod:
+ if args.use_bn_sync:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ ddp_args = {}
+ if args.ddp_static_graph:
+ # this doesn't exist in older PyTorch, arg only added if enabled
+ ddp_args["static_graph"] = True
+ model = torch.nn.parallel.DistributedDataParallel(
+ model, device_ids=[device], find_unused_parameters=True, **ddp_args
+ )
+
+ data = get_data(args, model_cfg)
+ assert len(data), "At least one train or eval dataset must be specified."
+ if args.trace:
+ assert "train" not in data, "Cannot train with traced model"
+
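+    # Weight-decay policy: 1-D tensors (biases, norm gains), anything whose name
+    # contains "bn", "ln", or "bias", and the logit scales go into a no-weight-decay
+    # group; all remaining parameters receive args.wd (or wd_pretrained / wd_new).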
+ exclude = (
+ lambda n, p: p.ndim < 2
+ or "bn" in n
+ or "ln" in n
+ or "bias" in n
+ or "logit_scale" in n
+ )
+ include = lambda n, p: not exclude(n, p)
+
+ named_parameters = list(model.named_parameters())
+
+ # freeze text encoder
+ text_freeze_parameters = [p for n, p in named_parameters if "text_branch" in n]
+
+ if args.freeze_text:
+ print("Freeze Text!!!!")
+ for k in text_freeze_parameters:
+ k.requires_grad = False
+
+ gain_or_bias_params = [
+ p for n, p in named_parameters if exclude(n, p) and p.requires_grad
+ ]
+ rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]
+
+    # set the weight-decay-related params to 0 when using the adam optimizer
+ if args.optimizer == "adam":
+ args.wd = 0
+ args.wd_pretrained = 0
+ args.wd_new = 0
+
+ if args.train_data is None:
+ optimizer = None
+ scheduler = None
+ else:
+ total_steps = data["train"].dataloader.num_batches * args.epochs
+
+ if args.split_opt:
+ for x in ["lr", "beta1", "beta2", "eps", "wd"]:
+ for y in ["_new", "_pretrained"]:
+ if getattr(args, x + y) is None:
+ setattr(args, x + y, getattr(args, x))
+
+ gain_or_bias_pretrained_params = [
+ p
+ for n, p in named_parameters
+ if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n)
+ ]
+ rest_pretrained_params = [
+ p
+ for n, p in named_parameters
+ if (include(n, p) and p.requires_grad) and is_pretrained_params(n)
+ ]
+ gain_or_bias_new_params = [
+ p
+ for n, p in named_parameters
+ if (exclude(n, p) and p.requires_grad) and (not is_pretrained_params(n))
+ ]
+ rest_new_params = [
+ p
+ for n, p in named_parameters
+ if (include(n, p) and p.requires_grad) and (not is_pretrained_params(n))
+ ]
+ pretrained_params_optimizer = get_optimizer(
+ [
+ {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0},
+ {
+ "params": rest_pretrained_params,
+ "weight_decay": args.wd_pretrained,
+ },
+ ],
+ lr=args.lr_pretrained,
+ betas=(args.beta1_pretrained, args.beta2_pretrained),
+ eps=args.eps_pretrained,
+ momentum=args.momentum_pretrained,
+ optimizer_name=args.optimizer,
+ )
+ pretrained_params_scheduler = cosine_lr(
+ pretrained_params_optimizer,
+ args.lr_pretrained,
+ args.warmup,
+ total_steps,
+ )
+ new_params_optimizer = get_optimizer(
+ [
+ {"params": gain_or_bias_new_params, "weight_decay": 0.0},
+ {"params": rest_new_params, "weight_decay": args.wd_new},
+ ],
+ lr=args.lr_new,
+ betas=(args.beta1_new, args.beta2_new),
+ eps=args.eps_new,
+ momentum=args.momentum_new,
+ optimizer_name=args.optimizer,
+ )
+
+ new_params_scheduler = cosine_lr(
+ new_params_optimizer, args.lr_new, args.warmup, total_steps
+ )
+
+ optimizer = {
+ "pretrained": pretrained_params_optimizer,
+ "new": new_params_optimizer,
+ }
+ scheduler = {
+ "pretrained": pretrained_params_scheduler,
+ "new": new_params_scheduler,
+ }
+
+ if args.horovod:
+ pretrained_params_optimizer = hvd.DistributedOptimizer(
+ pretrained_params_optimizer,
+ named_parameters=model.named_parameters(),
+ )
+ new_params_optimizer = hvd.DistributedOptimizer(
+ new_params_optimizer, named_parameters=model.named_parameters()
+ )
+ hvd.broadcast_parameters(model.state_dict(), root_rank=0)
+ hvd.broadcast_optimizer_state(pretrained_params_optimizer, root_rank=0)
+ hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0)
+ else:
+ optimizer = get_optimizer(
+ [
+ {"params": gain_or_bias_params, "weight_decay": 0.0},
+ {"params": rest_params, "weight_decay": args.wd},
+ ],
+ lr=args.lr,
+ betas=(args.beta1, args.beta2),
+ eps=args.eps,
+ momentum=args.momentum,
+ optimizer_name=args.optimizer,
+ )
+
+ scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)
+
+ if args.horovod:
+ optimizer = hvd.DistributedOptimizer(
+ optimizer, named_parameters=model.named_parameters()
+ )
+ hvd.broadcast_parameters(model.state_dict(), root_rank=0)
+ hvd.broadcast_optimizer_state(optimizer, root_rank=0)
+
+ scaler = GradScaler() if args.precision == "amp" else None
+
+ # optionally resume from a checkpoint
+ start_epoch = 0
+ if args.resume is not None:
+ if os.path.isfile(args.resume):
+ checkpoint = torch.load(args.resume, map_location=device)
+ if "epoch" in checkpoint:
+ # resuming a train checkpoint w/ epoch and optimizer state
+ start_epoch = checkpoint["epoch"]
+ sd = checkpoint["state_dict"]
+ if not args.distributed and next(iter(sd.items()))[0].startswith(
+ "module"
+ ):
+ sd = {k[len("module.") :]: v for k, v in sd.items()}
+ model.load_state_dict(sd)
+ if args.split_opt:
+ if optimizer is not None:
+ for k, o_ in optimizer.items():
+ o_.load_state_dict(checkpoint[k + "_" + "optimizer"])
+                elif optimizer is not None:
+                    optimizer.load_state_dict(checkpoint["optimizer"])
+ if scaler is not None and "scaler" in checkpoint:
+ scaler.load_state_dict(checkpoint["scaler"])
+ logging.info(
+ f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})"
+ )
+ else:
+ # loading a bare (model only) checkpoint for fine-tune or evaluation
+ model.load_state_dict(checkpoint)
+ logging.info(
+ f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})"
+ )
+ if args.freeze_text:
+ print("Freeze Text!!!!")
+ for k in text_freeze_parameters:
+ k.requires_grad = False
+ else:
+ logging.info("=> no checkpoint found at '{}'".format(args.resume))
+
+ cudnn.benchmark = True
+ cudnn.deterministic = False
+
+ # determine if this worker should save logs and checkpoints. only do so if it is rank == 0
+ args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args)
+ writer = None
+ if args.save_logs and args.tensorboard:
+ assert tensorboard is not None, "Please install tensorboard."
+ writer = tensorboard.SummaryWriter(args.tensorboard_path)
+
+ if args.wandb and is_master(args):
+ assert wandb is not None, "Please install wandb."
+ logging.debug("Starting wandb.")
+ args.train_sz = data["train"].dataloader.num_samples
+ if args.val_data is not None:
+ args.val_sz = data["val"].dataloader.num_samples
+ # you will have to configure this for your project!
+ wandb.init(
+ project="clap",
+ notes=args.wandb_notes,
+ name=args.wandb_notes,
+ tags=[],
+ config=vars(args),
+ )
+ if args.debug:
+ wandb.watch(model, log="all")
+ wandb.save(params_file)
+ logging.debug("Finished loading wandb.")
+
+ if "train" not in data:
+ evaluate(model, data, start_epoch, args, writer)
+ return
+ elif start_epoch == 0 and "val" in data and not args.no_eval:
+ evaluate(model, data, 0, args, writer)
+ # print(f'rank {args.rank}, Start First Evaluation')# (yusong): for debug
+ if args.save_top_performance:
+ current_top_k_ckpt_metrics = {
+ i: 0 for i in range(args.save_top_performance)
+ } # initialize the top-k metric for ckpts to 0
+
+ # print(f'rank {args.rank}, Start Training') # (yusong): for debug
+ for epoch in range(start_epoch, args.epochs):
+        # freeze the text parameters from (and including) epoch args.freeze_text_after; the default of -1 disables this
+        if epoch == args.freeze_text_after:
+            print("Text pretrained parameters are frozen starting from this epoch.")
+ for k in text_freeze_parameters:
+ k.requires_grad = False
+ if is_master(args):
+ logging.info(f"Start epoch {epoch}")
+
+ train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer)
+ completed_epoch = epoch + 1
+
+ if (
+ any(v in data for v in ("val", "imagenet-val", "imagenet-v2"))
+ and not args.no_eval
+ ):
+ metrics = evaluate(model, data, completed_epoch, args, writer)
+ if args.save_top_performance:
+ top_k_dataset = args.top_k_checkpoint_select_dataset
+ top_k_metric = args.top_k_checkpoint_select_metric
+ filtered_metrics = [
+ v
+ for k, v in metrics.items()
+ if top_k_metric in k and top_k_dataset in k
+                ]  # check all R@10 metrics (across all datasets) and use them to update the checkpoint
+ # Saving checkpoints.
+ if args.save_logs:
+ if args.split_opt:
+ opt_dict = {
+ k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items()
+ }
+ else:
+ opt_dict = {"optimizer": optimizer.state_dict()}
+ checkpoint_dict = {
+ "epoch": completed_epoch,
+ "name": args.name,
+ "state_dict": model.state_dict(),
+ }
+ checkpoint_dict.update(opt_dict)
+ if scaler is not None:
+ checkpoint_dict["scaler"] = scaler.state_dict()
+
+ if completed_epoch == args.epochs or (
+ args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
+ ):
+ torch.save(
+ checkpoint_dict,
+ os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"),
+ )
+ if args.save_most_recent:
+ torch.save(
+ checkpoint_dict,
+ os.path.join(args.checkpoint_path, f"epoch_latest.pt"),
+ )
+ if args.save_top_performance and not args.no_eval:
+ update_top_k_performance(
+ filtered_metrics,
+ current_top_k_ckpt_metrics,
+ args,
+ checkpoint_dict,
+ bignumbetter=True,
+ )
+
+ if args.wandb and is_master(args):
+ wandb.finish()
+
+
+def copy_codebase(args):
+ from shutil import copytree, ignore_patterns
+
+ new_code_path = os.path.join(args.logs, args.name, "code")
+ if os.path.exists(new_code_path):
+ print(
+ f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
+ )
+ return -1
+ print(f"Copying codebase to {new_code_path}")
+ current_code_path = os.path.realpath(__file__)
+ for _ in range(3):
+ current_code_path = os.path.dirname(current_code_path)
+ copytree(
+ current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb")
+ )
+ print("Done copying code.")
+ return 1
+
+
+if __name__ == "__main__":
+ main()
diff --git a/audioldm/clap/training/params.py b/audioldm/clap/training/params.py
new file mode 100755
index 0000000000000000000000000000000000000000..b1933e3a78ff583733846ea285d56eb0a0b892a5
--- /dev/null
+++ b/audioldm/clap/training/params.py
@@ -0,0 +1,569 @@
+import argparse
+import os
+
+CACHE_DIR = os.getenv("AUDIOLDM_CACHE_DIR", "~/.cache")
+
+
+def get_default_params(model_name):
+ # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
+ model_name = model_name.lower()
+ if "vit" in model_name:
+ return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
+ else:
+ return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--train-data",
+ type=str,
+ default=None,
+        help="Path to h5 file with training data",
+ )
+ parser.add_argument(
+ "--val-data",
+ type=str,
+ default=None,
+ help="Path to h5 file with validation data",
+ )
+ parser.add_argument(
+ "--freeze-text",
+ default=False,
+ action="store_true",
+ help="if you need to freeze the text encoder, make this True",
+ )
+ parser.add_argument(
+ "--freeze-text-after",
+ type=int,
+ default=-1,
+ help="if you need to freeze the text encoder after (include) epoch x, set this param to x. Set -1 to disable it",
+ )
+ parser.add_argument(
+ "--train-ipc",
+ type=str,
+ default=None,
+        help="Path to npy file of the number of instances per class in the training data",
+ )
+ parser.add_argument(
+ "--val-ipc",
+ type=str,
+ default=None,
+        help="Path to npy file of the number of instances per class in the validation data",
+ )
+ parser.add_argument(
+ "--train-num-samples",
+ type=int,
+ default=None,
+ help="Number of samples in dataset. Required for webdataset if not available in info file.",
+ )
+ parser.add_argument(
+ "--val-num-samples",
+ type=int,
+ default=None,
+ help="Number of samples in dataset. Useful for webdataset if not available in info file.",
+ )
+ parser.add_argument(
+ "--dataset-type",
+ choices=["webdataset", "csv", "auto", "toy"],
+ default="auto",
+ help="Which type of dataset to process.",
+ )
+ parser.add_argument(
+ "--csv-separator",
+ type=str,
+ default="\t",
+ help="For csv-like datasets, which separator to use.",
+ )
+ parser.add_argument(
+ "--csv-img-key",
+ type=str,
+ default="filepath",
+ help="For csv-like datasets, the name of the key for the image paths.",
+ )
+ parser.add_argument(
+ "--csv-caption-key",
+ type=str,
+ default="title",
+ help="For csv-like datasets, the name of the key for the captions.",
+ )
+ parser.add_argument(
+ "--imagenet-val",
+ type=str,
+ default=None,
+ help="Path to imagenet val set for conducting zero shot evaluation.",
+ )
+ parser.add_argument(
+ "--imagenet-v2",
+ type=str,
+ default=None,
+ help="Path to imagenet v2 for conducting zero shot evaluation.",
+ )
+ parser.add_argument(
+ "--datasetnames",
+ nargs="+",
+ default=None,
+        help="If loading webdataset, specify the dataset names to load. Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects",
+ )
+ parser.add_argument(
+ "--full-train-dataset",
+ nargs="+",
+ default=None,
+        help="Which datasets will be trained with all of their subsets (train+test)",
+ )
+ parser.add_argument(
+ "--exclude-eval-dataset",
+ nargs="+",
+ default=None,
+        help="Which datasets will be excluded from evaluation",
+ )
+ parser.add_argument(
+ "--datasetinfos",
+ nargs="+",
+ default=None,
+        help="If loading webdataset, specify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval",
+ )
+ parser.add_argument(
+ "--dataset-proportion",
+ type=float,
+ default=1.0,
+        help="The proportion of the dataset to train on.",
+ )
+ parser.add_argument(
+ "--remotedata",
+ default=False,
+ action="store_true",
+ help="if the dataset is remote, set this flag",
+ )
+ parser.add_argument(
+ "--class-label-path",
+ type=str,
+ default=None,
+ help="The path of the class label pickle or csv.",
+ )
+ parser.add_argument(
+ "--datasetpath",
+ type=str,
+ default="/mnt/audio_clip/webdataset_tar",
+ help="The path to the dataset",
+ )
+ parser.add_argument(
+ "--logs",
+ type=str,
+ default="./logs/",
+ help="Where to store tensorboard logs. Use None to avoid storing logs.",
+ )
+ parser.add_argument(
+ "--log-local",
+ action="store_true",
+ default=False,
+ help="log files on local master, otherwise global master only.",
+ )
+ parser.add_argument(
+ "--name",
+ type=str,
+ default=None,
+ help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
+ )
+ parser.add_argument(
+ "--workers", type=int, default=1, help="Number of workers per GPU."
+ )
+ parser.add_argument(
+ "--batch-size", type=int, default=64, help="Batch size per GPU."
+ )
+ parser.add_argument(
+ "--epochs", type=int, default=32, help="Number of epochs to train for."
+ )
+ parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
+ parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
+ parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
+ parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
+    parser.add_argument("--momentum", type=float, default=None, help="SGD momentum.")
+ parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
+
+ parser.add_argument(
+ "--split-opt",
+ action="store_true",
+ default=False,
+        help="Use separate optimizers (and learning-rate schedules) for pretrained and new parameters.",
+ )
+ parser.add_argument(
+ "--lr-pretrained", type=float, default=None, help="Learning rate for text."
+ )
+ parser.add_argument(
+ "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text."
+ )
+ parser.add_argument(
+ "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text."
+ )
+ parser.add_argument(
+ "--eps-pretrained", type=float, default=None, help="Adam epsilon for text."
+ )
+ parser.add_argument(
+ "--wd-pretrained", type=float, default=0.2, help="Weight decay for text."
+ )
+ parser.add_argument(
+ "--momentum-pretrained", type=float, default=0.9, help="Momentum for text."
+ )
+ parser.add_argument(
+ "--lr-new", type=float, default=None, help="Learning rate for audio."
+ )
+ parser.add_argument(
+ "--beta1-new", type=float, default=None, help="Adam beta 1 for audio."
+ )
+ parser.add_argument(
+ "--beta2-new", type=float, default=None, help="Adam beta 2 for audio."
+ )
+ parser.add_argument(
+ "--eps-new", type=float, default=None, help="Adam epsilon for audio."
+ )
+ parser.add_argument(
+ "--wd-new", type=float, default=0.2, help="Weight decay for audio."
+ )
+ parser.add_argument(
+ "--momentum-new", type=float, default=0.9, help="Momentum for audio."
+ )
+ parser.add_argument(
+ "--warmup", type=int, default=10000, help="Number of steps to warmup for."
+ )
+ parser.add_argument(
+ "--use-bn-sync",
+ default=False,
+ action="store_true",
+ help="Whether to use batch norm sync.",
+ )
+ parser.add_argument(
+ "--skip-scheduler",
+ action="store_true",
+ default=False,
+ help="Use this flag to skip the learning rate decay.",
+ )
+ parser.add_argument(
+ "--save-frequency", type=int, default=1, help="How often to save checkpoints."
+ )
+ parser.add_argument(
+ "--save-top-performance",
+ type=int,
+ default=0,
+        help="Save the top x best-performing checkpoints if the value is > 0",
+ )
+ parser.add_argument(
+ "--save-most-recent",
+ action="store_true",
+ default=False,
+ help="Always save the most recent model trained to epoch_latest.pt.",
+ )
+ parser.add_argument(
+ "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
+ )
+ parser.add_argument(
+ "--val-frequency",
+ type=int,
+ default=1,
+ help="How often to run evaluation with val data.",
+ )
+ parser.add_argument(
+ "--resume",
+ default=None,
+ type=str,
+ help="path to latest checkpoint (default: none)",
+ )
+ parser.add_argument(
+ "--precision",
+ choices=["amp", "fp16", "fp32"],
+ default="amp",
+ help="Floating point precision.",
+ )
+ parser.add_argument(
+ "--amodel",
+ type=str,
+ default="RN50",
+ help="Name of the audio backbone to use.",
+ )
+ parser.add_argument(
+ "--tmodel",
+ type=str,
+ default="transformer",
+ help="Name of the text backbone to use. Can be [transformer, bert, roberta, bart]",
+ )
+ parser.add_argument(
+ "--pretrained-audio",
+ default="",
+ type=str,
+        help="Use pretrained audio model weights for the audio encoder of CLAP",
+ )
+ parser.add_argument(
+ "--pretrained-text",
+ default="",
+ type=str,
+        help="Use pretrained text model weights for the text encoder of CLAP",
+ )
+ parser.add_argument(
+ "--pretrained",
+ default="",
+ type=str,
+        help="Use pretrained CLIP model weights with the specified tag or file path.",
+ )
+ parser.add_argument(
+ "--pretrained-image",
+ default=False,
+ action="store_true",
+ help="Load imagenet pretrained weights for image tower backbone if available.",
+ )
+ parser.add_argument(
+ "--lock-image",
+ default=False,
+ action="store_true",
+ help="Lock full image tower by disabling gradients.",
+ )
+ parser.add_argument(
+ "--lock-image-unlocked-groups",
+ type=int,
+ default=0,
+ help="Leave last n image tower layer groups unlocked.",
+ )
+ parser.add_argument(
+ "--lock-image-freeze-bn-stats",
+ default=False,
+ action="store_true",
+ help="Freeze BatchNorm running stats in image tower for any locked layers.",
+ )
+ parser.add_argument(
+ "--local-loss",
+ default=False,
+ action="store_true",
+ help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)",
+ )
+ parser.add_argument(
+ "--gather-with-grad",
+ default=False,
+ action="store_true",
+ help="enable full distributed gradient for feature gather",
+ )
+ parser.add_argument(
+ "--force-quick-gelu",
+ default=False,
+ action="store_true",
+ help="Force use of QuickGELU activation for non-OpenAI transformer models.",
+ )
+ parser.add_argument(
+ "--torchscript",
+ default=False,
+ action="store_true",
+ help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
+ )
+ parser.add_argument(
+ "--trace",
+ default=False,
+ action="store_true",
+ help="torch.jit.trace the model for inference / eval only",
+ )
+ # arguments for distributed training
+ parser.add_argument(
+ "--dist-url",
+ default="env://",
+ type=str,
+ help="url used to set up distributed training",
+ )
+ parser.add_argument(
+ "--dist-backend", default="nccl", type=str, help="distributed backend"
+ )
+ parser.add_argument(
+ "--report-to",
+ default="",
+ type=str,
+ help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']",
+ )
+ parser.add_argument(
+ "--wandb-notes", default="", type=str, help="Notes if logging with wandb"
+ )
+ parser.add_argument(
+ "--C", type=float, default=3.16, help="inverse regularizer for logistic reg."
+ )
+ parser.add_argument(
+ "--debug",
+ default=False,
+ action="store_true",
+ help="If true, more information is logged.",
+ )
+ parser.add_argument(
+ "--copy-codebase",
+ default=False,
+ action="store_true",
+        help="If true, we copy the entire codebase to the log directory and execute from there.",
+ )
+ parser.add_argument(
+ "--horovod",
+ default=False,
+ action="store_true",
+ help="Use horovod for distributed training.",
+ )
+ parser.add_argument(
+ "--ddp-static-graph",
+ default=False,
+ action="store_true",
+ help="Enable static graph optimization for DDP in PyTorch >= 1.11.",
+ )
+ parser.add_argument(
+ "--no-set-device-rank",
+ default=False,
+ action="store_true",
+ help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
+ )
+ parser.add_argument("--seed", type=int, default=4242, help="Default random seed.")
+
+ parser.add_argument(
+ "--top-k-checkpoint-select-dataset",
+ type=str,
+ default="all",
+        help="The dataset used for selecting the top-k checkpoints.",
+ )
+
+    # R@10, R@5, R@1, mAP@10
+ parser.add_argument(
+ "--top-k-checkpoint-select-metric",
+ type=str,
+ default="_R@10",
+        help="The metric used for selecting the top-k checkpoints.",
+ )
+ parser.add_argument(
+ "--openai-model-cache-dir",
+ type=str,
+ default=f"{CACHE_DIR}/clip",
+ help="Directory to download OpenAI models.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="adamw",
+ help="can be AdamW or SGD",
+ )
+ parser.add_argument(
+ "--parallel-eval",
+ default=False,
+ action="store_true",
+ help="Eval in parallel (multi-GPU, multi-node).",
+ )
+
+ parser.add_argument(
+ "--no-eval",
+ default=False,
+ action="store_true",
+ help="Training without evaluation.",
+ )
+
+ parser.add_argument(
+ "--lp-mlp",
+ default=False,
+ action="store_true",
+ help="Linear Probe using MLP layer or not.",
+ )
+
+ parser.add_argument(
+ "--lp-freeze",
+ default=False,
+ action="store_true",
+        help="Linear probe with a frozen CLAP backbone or not",
+ )
+
+ parser.add_argument(
+ "--lp-act",
+ default="None",
+ type=str,
+ help="Options are ['relu','elu','prelu','softmax','sigmoid']",
+ )
+
+ parser.add_argument(
+ "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe."
+ )
+
+ parser.add_argument(
+ "--lp-metrics",
+ type=str,
+ default="map,mauc,acc",
+ help="Metrics of Linear Probe.",
+ )
+
+ parser.add_argument(
+ "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe"
+ )
+ parser.add_argument(
+ "--kappa",
+ type=float,
+ default=0,
+ help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss",
+ )
+
+ parser.add_argument(
+ "--data-filling",
+ type=str,
+ default="pad",
+ help="type of data filling when the audio length is shorter than the max length."
+ "Can be one of the following: repeat, repeatpad, pad",
+ )
+ parser.add_argument(
+ "--data-truncating",
+ type=str,
+ default="rand_trunc",
+ help="type of data truncation when the audio length is longer than the max length."
+ "Can be one of the following: rand_trunc, fusion",
+ )
+
+ parser.add_argument(
+ "--clap-mlploss",
+ default=False,
+ action="store_true",
+ help="Using MLP loss for CLAP model or not",
+ )
+
+ parser.add_argument(
+ "--wandb-id",
+ type=str,
+ default=None,
+ help="the id of wandb experiment to restore.",
+ )
+
+ parser.add_argument(
+        "--sleep", type=float, default=0, help="sleep n seconds before starting training"
+ )
+
+ # variable length processing
+ parser.add_argument(
+ "--enable-fusion",
+ default=False,
+ action="store_true",
+        help="Enable feature fusion for variable-length data",
+ )
+
+ parser.add_argument(
+ "--fusion-type",
+ type=str,
+ default="None",
+ help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']",
+ )
+
+ parser.add_argument(
+ "--mixup",
+ default=False,
+ action="store_true",
+ help="Enable mixup in finetuning training.",
+ )
+ parser.add_argument(
+ "--text-augment-selection",
+ type=str,
+ default=None,
+ help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']",
+ )
+
+ args = parser.parse_args()
+
+ # If some params are not passed, we use the default values based on model name.
+ default_params = get_default_params(args.amodel)
+ for name, val in default_params.items():
+ if getattr(args, name) is None:
+ setattr(args, name, val)
+
+ return args
diff --git a/audioldm/clap/training/scheduler.py b/audioldm/clap/training/scheduler.py
new file mode 100755
index 0000000000000000000000000000000000000000..7151ffbab25a113673b7627027b443b27f22cb0f
--- /dev/null
+++ b/audioldm/clap/training/scheduler.py
@@ -0,0 +1,24 @@
+import numpy as np
+
+
+def assign_learning_rate(optimizer, new_lr):
+ for param_group in optimizer.param_groups:
+ param_group["lr"] = new_lr
+
+
+def _warmup_lr(base_lr, warmup_length, step):
+ return base_lr * (step + 1) / warmup_length
+
+
+def cosine_lr(optimizer, base_lr, warmup_length, steps):
+ def _lr_adjuster(step):
+ if step < warmup_length:
+ lr = _warmup_lr(base_lr, warmup_length, step)
+ else:
+ e = step - warmup_length
+ es = steps - warmup_length
+ lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
+ assign_learning_rate(optimizer, lr)
+ return lr
+
+ return _lr_adjuster
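+
+
+# Usage sketch (illustrative only, not wired into training here): cosine_lr returns
+# a closure that is called with the global step; it both writes the new learning
+# rate into the optimizer's param groups and returns it. During warmup the LR ramps
+# linearly up to base_lr, then follows
+#   lr = 0.5 * (1 + cos(pi * e / es)) * base_lr
+# with e = step - warmup_length and es = steps - warmup_length, decaying toward 0
+# by the final step. Hypothetical example with a plain SGD optimizer:
+#   opt = torch.optim.SGD(model.parameters(), lr=1e-3)
+#   adjust = cosine_lr(opt, base_lr=1e-3, warmup_length=100, steps=10_000)
+#   for step in range(10_000):
+#       lr = adjust(step)
+#       ...one training step...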
diff --git a/audioldm/clap/training/train.py b/audioldm/clap/training/train.py
new file mode 100755
index 0000000000000000000000000000000000000000..f5759c4679d2ee9c0748444adf66b8453cf09728
--- /dev/null
+++ b/audioldm/clap/training/train.py
@@ -0,0 +1,838 @@
+import json
+import logging
+import math
+import os
+import time
+from contextlib import suppress
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
+from open_clip import ClipLoss, gather_features
+from .distributed import is_master
+from .zero_shot import zero_shot_eval
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def unwrap_model(model):
+ if hasattr(model, "module"):
+ return model.module
+ else:
+ return model
+
+
+def train_one_epoch(
+ model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None
+):
+ device = torch.device(args.device)
+ autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
+ model.train()
+ loss = ClipLoss(
+ local_loss=args.local_loss,
+ gather_with_grad=args.gather_with_grad,
+ cache_labels=True,
+ rank=args.rank,
+ world_size=args.world_size,
+ use_horovod=args.horovod,
+ mlp_loss=args.clap_mlploss,
+ weight_loss_kappa=args.kappa,
+ )
+
+ dataloader, sampler = data["train"].dataloader, data["train"].sampler
+ if args.distributed and sampler is not None:
+ sampler.set_epoch(epoch)
+ num_batches_per_epoch = dataloader.num_batches
+ sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))
+
+ # for toy dataset
+ if args.dataset_type == "toy":
+ dataloader.dataset.generate_queue()
+
+ loss_m = AverageMeter()
+ batch_time_m = AverageMeter()
+ data_time_m = AverageMeter()
+ end = time.time()
+
+ for i, batch in enumerate(dataloader):
+ # logging.info(f"batch {i} of {num_batches_per_epoch}")
+ step = num_batches_per_epoch * epoch + i
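+        # With --split-opt, `scheduler` and `optimizer` are dicts keyed by
+        # "pretrained" / "new", so each parameter group gets its LR updated every step.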
+ if isinstance(scheduler, dict):
+ for s in scheduler.values():
+ s(step)
+ else:
+ scheduler(step)
+        audios = batch  # contains mel_spec, waveform, and longer list
+ texts = batch["text"]
+ # audios = audios.to(device=device, non_blocking=True)
+ # texts = texts.to(device=device, non_blocking=True)
+
+ data_time_m.update(time.time() - end)
+ if isinstance(optimizer, dict):
+ for o_ in optimizer.values():
+ o_.zero_grad()
+ else:
+ optimizer.zero_grad()
+
+ with autocast():
+ (
+ audio_features,
+ text_features,
+ audio_features_mlp,
+ text_features_mlp,
+ logit_scale_a,
+ logit_scale_t,
+ ) = model(audios, texts, device)
+
+ if args.clap_mlploss:
+ total_loss = loss(
+ audio_features=audio_features,
+ text_features=text_features,
+ logit_scale_a=logit_scale_a,
+ logit_scale_t=logit_scale_t,
+ audio_features_mlp=audio_features_mlp,
+ text_features_mlp=text_features_mlp,
+ )
+ else:
+ total_loss = loss(
+ audio_features=audio_features,
+ text_features=text_features,
+ logit_scale_a=logit_scale_a,
+ )
+ if isinstance(optimizer, dict):
+ if scaler is not None:
+ scaler.scale(total_loss).backward()
+ for o_ in optimizer.values():
+ if args.horovod:
+ o_.synchronize()
+ scaler.unscale_(o_)
+ with o_.skip_synchronize():
+ scaler.step(o_)
+ else:
+ scaler.step(o_)
+ scaler.update()
+ else:
+ total_loss.backward()
+ for o_ in optimizer.values():
+ o_.step()
+ else:
+ if scaler is not None:
+ scaler.scale(total_loss).backward()
+ if args.horovod:
+ optimizer.synchronize()
+ scaler.unscale_(optimizer)
+ with optimizer.skip_synchronize():
+ scaler.step(optimizer)
+ else:
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ total_loss.backward()
+ optimizer.step()
+
+ # Note: we clamp to 4.6052 = ln(100), as in the original paper.
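+        # (ln(100) caps the exponentiated logit scale at 100, i.e. a minimum softmax
+        # temperature of 0.01, which keeps the contrastive logits from blowing up.)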
+ with torch.no_grad():
+ unwrap_model(model).logit_scale_a.clamp_(0, math.log(100))
+ if args.clap_mlploss:
+ unwrap_model(model).logit_scale_t.clamp_(0, math.log(100))
+
+ batch_time_m.update(time.time() - end)
+ end = time.time()
+ batch_count = i + 1
+ if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch):
+ if isinstance(audios, dict):
+ batch_size = len(audios["waveform"])
+ else:
+ batch_size = len(audios)
+ num_samples = batch_count * batch_size * args.world_size
+ samples_per_epoch = dataloader.num_samples
+ percent_complete = 100.0 * batch_count / num_batches_per_epoch
+
+            # NOTE: loss is coarsely sampled; it is only recorded on the master node and only at log updates
+ loss_m.update(total_loss.item(), batch_size)
+ logit_scale_scalar_a = logit_scale_a.item()
+ logit_scale_scalar_t = logit_scale_t.item()
+ if isinstance(optimizer, dict):
+ if args.clap_mlploss:
+ logging.info(
+ f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
+ f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
+ f"Data (t): {data_time_m.avg:.3f} "
+ f"Batch (t): {batch_time_m.avg:.3f} "
+ f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} "
+ f"Logit Scale Audio: {logit_scale_scalar_a:.3f}"
+ f"Logit Scale Text: {logit_scale_scalar_t:.3f}"
+ )
+ log_data = {
+ "loss": loss_m.val,
+ "data_time": data_time_m.val,
+ "batch_time": batch_time_m.val,
+ "scale_audio": logit_scale_scalar_a,
+ "scale_text": logit_scale_scalar_t,
+ "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()],
+ }
+ else:
+ logging.info(
+ f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
+ f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
+ f"Data (t): {data_time_m.avg:.3f} "
+ f"Batch (t): {batch_time_m.avg:.3f} "
+ f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} "
+ f"Logit Scale Audio: {logit_scale_scalar_a:.3f}"
+ )
+ log_data = {
+ "loss": loss_m.val,
+ "data_time": data_time_m.val,
+ "batch_time": batch_time_m.val,
+ "scale_audio": logit_scale_scalar_a,
+ "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()],
+ }
+
+ else:
+ if args.clap_mlploss:
+ logging.info(
+ f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
+ f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
+ f"Data (t): {data_time_m.avg:.3f} "
+ f"Batch (t): {batch_time_m.avg:.3f} "
+ f"LR: {optimizer.param_groups[0]['lr']:5f} "
+ f"Logit Scale Audio: {logit_scale_scalar_a:.3f}"
+ f"Logit Scale Text: {logit_scale_scalar_t:.3f}"
+ )
+
+                    # Save train loss / etc. Use the raw (non-averaged) meter values, since loggers apply their own smoothing
+ log_data = {
+ "loss": loss_m.val,
+ "data_time": data_time_m.val,
+ "batch_time": batch_time_m.val,
+ "scale_audio": logit_scale_scalar_a,
+ "scale_text": logit_scale_scalar_t,
+ "lr": optimizer.param_groups[0]["lr"],
+ }
+ else:
+ logging.info(
+ f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
+ f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
+ f"Data (t): {data_time_m.avg:.3f} "
+ f"Batch (t): {batch_time_m.avg:.3f} "
+ f"LR: {optimizer.param_groups[0]['lr']:5f} "
+ f"Logit Scale Audio: {logit_scale_scalar_a:.3f}"
+ )
+
+                    # Save train loss / etc. Use the raw (non-averaged) meter values, since loggers apply their own smoothing
+ log_data = {
+ "loss": loss_m.val,
+ "data_time": data_time_m.val,
+ "batch_time": batch_time_m.val,
+ "scale_audio": logit_scale_scalar_a,
+ "lr": optimizer.param_groups[0]["lr"],
+ }
+ for name, val in log_data.items():
+ name = "train/" + name
+ if tb_writer is not None:
+ tb_writer.add_scalar(name, val, step)
+ if args.wandb:
+ assert wandb is not None, "Please install wandb."
+ wandb.log({name: val, "step": step})
+
+ # resetting batch / data time meters per log window
+ batch_time_m.reset()
+ data_time_m.reset()
+ # end for
+
+
+def evaluate(model, data, epoch, args, tb_writer=None):
+ metrics = {}
+ if not args.parallel_eval:
+ if not is_master(args):
+ return metrics
+ device = torch.device(args.device)
+ model.eval()
+
+ # CHANGE
+ # zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
+ # metrics.update(zero_shot_metrics)
+ if is_master(args):
+ print("Evaluating...")
+ autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
+ if args.val_dataset_names == ["Clotho", "audiocaps"]:
+        # if only Clotho and audiocaps are used, we use a different evaluation function,
+        # because their valid and test sets contain 5 texts for each audio.
+ if args.parallel_eval:
+ # (yusong): just a hack here. Don't use parallel eval when evaluating only clotho and audiocaps.
+ raise NotImplementedError(
+ "Parallel evaluation not supported for eval only Clotho and audiocaps."
+ )
+ val_metrics_per_dataset = evaluate_clotho_audiocaps(
+ model, data, epoch, args, autocast, device, tb_writer
+ )
+ for m in val_metrics_per_dataset.values():
+ metrics.update(m)
+ if "epoch" not in metrics.keys():
+ metrics.update({"epoch": epoch})
+ metrics = select_top_metric_clotho_audiocaps(
+ metrics, val_metrics_per_dataset, args
+ )
+ elif "val" in data and (
+ args.val_frequency
+ and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)
+ ):
+ dataloader = data["val"].dataloader
+ num_samples = 0
+ samples_per_val = dataloader.num_samples
+
+ # FIXME this does not scale past small eval datasets
+ # all_audio_features @ all_text_features will blow up memory and compute very quickly
+ eval_info = {}
+ if args.clap_mlploss:
+ eval_info["all"] = {
+ "cumulative_loss": 0.0,
+ "num_samples": 0,
+ "all_audio_features": [],
+ "all_text_features": [],
+ "all_audio_features_mlp": [],
+ "all_text_features_mlp": [],
+ } # cumulative_loss = 0.0
+ else:
+ eval_info["all"] = {
+ "cumulative_loss": 0.0,
+ "num_samples": 0,
+ "all_audio_features": [],
+ "all_text_features": [],
+            }
+ # all_audio_features, all_text_features, all_audio_features_mlp, all_text_features_mlp = [], [], [], []
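+        # eval_info accumulates features in per-dataset buckets, keyed by the two path
+        # components preceding each webdataset shard filename (typically
+        # "<dataset>-<split>"), plus an "all" bucket that aggregates every sample.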
+ with torch.no_grad():
+ for i, batch in enumerate(dataloader):
+                audios = batch  # contains mel_spec, waveform, and longer list
+ texts = batch["text"]
+ # audios = audios.to(device=device, non_blocking=True)
+
+ all_names = list(
+ set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]])
+ )
+ for name in all_names:
+ if name not in eval_info.keys():
+ if args.clap_mlploss:
+ eval_info[name] = {
+ "cumulative_loss": 0.0,
+ "num_samples": 0,
+ "all_audio_features": [],
+ "all_text_features": [],
+ "all_audio_features_mlp": [],
+ "all_text_features_mlp": [],
+ }
+ else:
+ eval_info[name] = {
+ "cumulative_loss": 0.0,
+ "num_samples": 0,
+ "all_audio_features": [],
+ "all_text_features": [],
+ }
+ with autocast():
+ (
+ audio_features,
+ text_features,
+ audio_features_mlp,
+ text_features_mlp,
+ logit_scale_a,
+ logit_scale_t,
+ ) = model(audios, texts, device)
+
+ if args.parallel_eval:
+ # multi-GPU eval
+ if args.clap_mlploss:
+ (
+ audio_features,
+ text_features,
+ audio_features_mlp,
+ text_features_mlp,
+ ) = gather_features(
+ audio_features=audio_features,
+ text_features=text_features,
+ audio_features_mlp=audio_features_mlp,
+ text_features_mlp=text_features_mlp,
+ local_loss=False,
+ gather_with_grad=False,
+ rank=args.rank,
+ world_size=args.world_size,
+ use_horovod=args.horovod,
+ mlp_loss=args.clap_mlploss,
+ )
+ else:
+ (audio_features, text_features,) = gather_features(
+ audio_features=audio_features,
+ text_features=text_features,
+ local_loss=False,
+ gather_with_grad=False,
+ rank=args.rank,
+ world_size=args.world_size,
+ use_horovod=args.horovod,
+ mlp_loss=args.clap_mlploss,
+ )
+
+ if is_master(args):
+ num_samples += audio_features.shape[0]
+ for n in [*all_names, "all"]:
+ if n == "all":
+ eval_info[n]["all_audio_features"].append(
+ audio_features.cpu()
+ )
+ eval_info[n]["all_text_features"].append(
+ text_features.cpu()
+ )
+ if args.clap_mlploss:
+ eval_info[n]["all_audio_features_mlp"].append(
+ audio_features_mlp.cpu()
+ )
+ eval_info[n]["all_text_features_mlp"].append(
+ text_features_mlp.cpu()
+ )
+ else:
+ idx = np.where(
+ np.array(
+ [
+ "-".join(b.split("/")[-3:-1])
+ for b in batch["__url__"]
+ ]
+ )
+ == n
+ )[0]
+ eval_info[n]["all_audio_features"].append(
+ audio_features.cpu().index_select(
+ 0, torch.tensor(idx).long()
+ )
+ )
+ eval_info[n]["all_text_features"].append(
+ text_features.cpu().index_select(
+ 0, torch.tensor(idx).long()
+ )
+ )
+ if args.clap_mlploss:
+ eval_info[n]["all_audio_features_mlp"].append(
+ audio_features_mlp.cpu().index_select(
+ 0, torch.tensor(idx).long()
+ )
+ )
+ eval_info[n]["all_text_features_mlp"].append(
+ text_features_mlp.cpu().index_select(
+ 0, torch.tensor(idx).long()
+ )
+ )
+ # print(f'eval step {i}') # (yusong): for debug
+
+ # cumulative_loss += total_loss * batch_size
+ # num_samples += batch_size
+ if is_master(args) and (i % 100) == 0: # and i != 0:
+ logging.info(
+ f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]"
+ )
+ if is_master(args):
+ val_metrics_per_dataset = {}
+ for n in eval_info.keys():
+ if args.clap_mlploss:
+ metrics_single_dataset = get_metrics(
+ audio_features=torch.cat(
+ eval_info[n]["all_audio_features"]
+ ),
+ text_features=torch.cat(eval_info[n]["all_text_features"]),
+ logit_scale_a=logit_scale_a.cpu(),
+ audio_features_mlp=torch.cat(
+ eval_info[n]["all_audio_features_mlp"]
+ ),
+ text_features_mlp=torch.cat(
+ eval_info[n]["all_text_features_mlp"]
+ ),
+ logit_scale_t=logit_scale_t.cpu(),
+ mlp_loss=args.clap_mlploss,
+ )
+ else:
+ metrics_single_dataset = get_metrics(
+ audio_features=torch.cat(
+ eval_info[n]["all_audio_features"]
+ ),
+ text_features=torch.cat(eval_info[n]["all_text_features"]),
+ logit_scale_a=logit_scale_a.cpu(),
+ mlp_loss=args.clap_mlploss,
+ )
+ val_metrics_per_dataset[n] = {
+ n + "/" + k: v for k, v in metrics_single_dataset.items()
+ }
+ metrics.update(val_metrics_per_dataset[n])
+ if "epoch" not in metrics.keys():
+ metrics.update({"epoch": epoch})
+ if is_master(args):
+ if not metrics:
+ return metrics
+
+ logging.info(
+ f"Eval Epoch: {epoch} "
+ + "\n".join(
+ [
+ "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in m.items()])
+ for m in val_metrics_per_dataset.values()
+ ]
+ )
+ )
+
+ if args.save_logs:
+ for name, val in metrics.items():
+ if tb_writer is not None:
+ tb_writer.add_scalar(f"val/{name}", val, epoch)
+
+ with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
+ f.write(json.dumps(metrics))
+ f.write("\n")
+
+ if args.wandb:
+ assert wandb is not None, "Please install wandb."
+ for name, val in metrics.items():
+ wandb.log({f"val/{name}": val, "epoch": epoch})
+
+ return metrics
+ else:
+ return metrics
+
+
+def get_metrics(
+ audio_features,
+ text_features,
+ logit_scale_a,
+ audio_features_mlp=None,
+ text_features_mlp=None,
+ logit_scale_t=None,
+ mlp_loss=False,
+):
+ metrics = {}
+ if mlp_loss:
+        # Set up the audio-to-text & text-to-audio similarity matrices
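+        # With mlp_loss, similarities are computed across two branches: raw audio
+        # features against MLP-projected text features (scaled by logit_scale_a), and
+        # MLP-projected audio features against raw text features (scaled by
+        # logit_scale_t); the reported logits average the two directions.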
+ a_logits_per_audio = (
+ (logit_scale_a * audio_features @ text_features_mlp.t()).detach().cpu()
+ )
+ a_logits_per_text = a_logits_per_audio.t().detach().cpu()
+ t_logits_per_audio = (
+ (logit_scale_t * audio_features_mlp @ text_features.t()).detach().cpu()
+ )
+ t_logits_per_text = t_logits_per_audio.t().detach().cpu()
+
+ labels = torch.arange(audio_features.shape[0]).long()
+ # Change the loss from two terms into four terms with 2x2 combined CE loss
+ total_loss = (
+ F.cross_entropy(a_logits_per_audio, labels)
+ + F.cross_entropy(a_logits_per_text, labels)
+ + F.cross_entropy(t_logits_per_audio, labels)
+ + F.cross_entropy(t_logits_per_text, labels)
+ ) / 4
+
+ metrics[f"cumulative_loss"] = total_loss.item()
+ metrics[f"num_samples"] = audio_features.shape[0]
+
+ logits = {
+ "audio_to_text": (a_logits_per_audio + t_logits_per_audio) / 2,
+ "text_to_audio": (a_logits_per_text + t_logits_per_text) / 2,
+ }
+ ground_truth = torch.arange(len(text_features)).view(-1, 1)
+
+ else:
+ # print("text_features", text_features)
+ # print("text_features.shape", text_features.shape)
+ logits_per_audio = (
+ (logit_scale_a * audio_features @ text_features.t()).detach().cpu()
+ )
+ logits_per_text = logits_per_audio.t().detach().cpu()
+
+ labels = torch.arange(audio_features.shape[0]).long()
+ # Change the loss from two terms into four terms with 2x2 combined CE loss
+ total_loss = (
+ F.cross_entropy(logits_per_audio, labels)
+ + F.cross_entropy(logits_per_text, labels)
+ ) / 2
+
+ metrics[f"cumulative_loss"] = total_loss.item()
+ metrics[f"num_samples"] = audio_features.shape[0]
+
+ logits = {"audio_to_text": logits_per_audio, "text_to_audio": logits_per_text}
+
+ ground_truth = torch.arange(len(text_features)).view(-1, 1)
+
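+    # Rank-based retrieval metrics: argsort orders candidates by similarity, and
+    # `preds` is the 0-indexed position of the true match for each query, so
+    # R@k is the fraction of queries whose match lands in the top k and
+    # mAP@10 averages 1 / (rank + 1) for matches ranked below 10 (0 otherwise).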
+ for name, logit in logits.items():
+ ranking = torch.argsort(logit, descending=True)
+ preds = torch.where(ranking == ground_truth)[
+ 1
+ ] # (yusong) this line is slow because it uses single thread
+ preds = preds.detach().cpu().numpy()
+ metrics[f"{name}_mean_rank"] = preds.mean() + 1
+ metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
+ for k in [1, 5, 10]:
+ metrics[f"{name}_R@{k}"] = np.mean(preds < k)
+ # map@10
+ metrics[f"{name}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0))
+
+ return metrics
+
+
+def evaluate_clotho_audiocaps(
+ model, data, epoch, args, autocast, device, tb_writer=None
+):
+ """
+ Adapted from https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py.
+    1. for text-to-audio retrieval, run the retrieval 5 times (once per caption) and average the results
+    2. for R@1, R@5, R@10 in audio-to-text retrieval, take the best rank among the 5 texts
+    3. for mAP@10 in audio-to-text retrieval:
+        3.1: sort the ranks of the 5 texts
+        3.2: exclude ranks >= 10 (0-indexed)
+        3.3: compute the mAP over the remaining ranks: np.mean(np.arange(1, len(ranks)+1) / ranks).
+        (3.3) That is, take the ranks of the 5 texts that are < 10 and assign ascending positions as ground truth.
+        (3.3) E.g.: the ground truth for the best-ranked of the 5 texts is 1, for the second-best 2, etc.
+ """
+ # TODO: (yusong) only support single GPU evaluation and only support non-mlp case for now.
+ dataloader = data["val"].dataloader
+ with torch.no_grad():
+ eval_info = {}
+ for i, batch in enumerate(dataloader):
+            audios = batch  # contains mel_spec, waveform, and longer list
+
+ # each item in the list has 5 texts
+ if args.tmodel == "transformer":
+ from open_clip import tokenize
+
+ texts = [tokenize(t) for t in batch["full_text"]]
+ texts = torch.cat(texts)
+ else:
+ from .data import tokenizer
+
+ texts = [
+ tokenizer(t) for t in batch["full_text"]
+ ] # 5 texts for each audio
+ texts = {
+ k: torch.cat([t[k] for t in texts]) for k in texts[0].keys()
+ } # 5 x batch
+
+ # audios = audios.to(device=device, non_blocking=True)
+
+ all_names = list(
+ set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]])
+ )
+ for name in all_names:
+ if name not in eval_info.keys():
+ # we will not use mlp outputs even if args.clap_mlploss=True
+ eval_info[name] = {
+ "cumulative_loss": 0.0,
+ "num_samples": 0,
+ "all_audio_features": [],
+ "all_text_features": [],
+ }
+ with autocast():
+ audio_features = model(audios, None, device)
+ text_features = model(None, texts, device)
+ audio_features = F.normalize(audio_features, dim=-1)
+ text_features = F.normalize(text_features, dim=-1)
+
+ all_names = list(
+ set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]])
+ )
+ for n in all_names:
+ idx = np.where(
+ np.array(
+ ["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]
+ )
+ == n
+ )[0]
+ eval_info[n]["all_audio_features"].append(
+ audio_features.cpu().index_select(0, torch.tensor(idx).long())
+ )
+ # (yusong) please double-check. This is for selecting 5 text features at once.
+ # because idx is a list of indices in size of num_samples,
+ # and text_features is a tensor of size (5*num_samples, dim)
+ # so we need to select 5 consecutive indices at once for a single index in idx.
+ eval_info[n]["all_text_features"].append(
+ text_features.cpu()
+ .reshape([-1, 5, text_features.shape[1]])
+ .index_select(0, torch.tensor(idx).long())
+ .reshape([-1, text_features.shape[1]])
+ )
+
+ val_metrics_all = {}
+
+ for n in eval_info.keys():
+ logit_scale_a, logit_scale_t = model(None, None, device)
+ logit_scale_a = logit_scale_a.cpu()
+
+ audio_features = torch.cat(eval_info[n]["all_audio_features"], dim=0)
+ text_features = torch.cat(eval_info[n]["all_text_features"], dim=0)
+
+ logits_per_audio = (
+ (logit_scale_a * audio_features @ text_features.t()).detach().cpu()
+ )
+ logits_per_text = logits_per_audio.t().detach().cpu()
+
+ # logits_per_audio shape: [num_samples, num_samples*5]
+ # logits_per_text shape: [num_samples*5, num_samples]
+
+ logging.info(
+ f"dataset {n}, logits_per_audio shape: {logits_per_audio.shape}, "
+ f"logits_per_text shape: {logits_per_text.shape}"
+ )
+
+ metrics = {}
+ num_samples = audio_features.shape[0]
+ metrics[f"num_samples"] = num_samples
+
+ # (yusong) the following code is very important, please double-check:
+ # logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d]
+ # logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :]
+ # Those two are retrieving one of the 5 text for each audio.
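+            # Layout reminder: text features were stored as 5 consecutive rows per audio,
+            # so the reshapes group the 5 captions of each audio together and the index d
+            # picks the d-th caption of every audio.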
+ labels = torch.arange(audio_features.shape[0]).long()
+ audio_to_text_loss = [
+ F.cross_entropy(
+ logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d],
+ labels,
+ )
+ for d in range(5)
+ ]
+ text_to_audio_loss = [
+ F.cross_entropy(
+ logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :],
+ labels,
+ )
+ for d in range(5)
+ ]
+ total_loss = (np.mean(audio_to_text_loss) + np.mean(text_to_audio_loss)) / 2
+
+ metrics[f"cumulative_loss"] = total_loss.item()
+
+ # text to audio: do 5 times
+ pred_text = []
+ for d in range(5):
+ logit = logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :]
+ ground_truth = torch.arange(len(logit)).view(-1, 1)
+ ranking = torch.argsort(
+ logit, descending=True
+ ) # [num_samples, num_samples]
+ preds = torch.where(ranking == ground_truth)[1]
+ pred_text.append(preds.detach().cpu().numpy())
+ pred_text_concat = np.concatenate(pred_text, axis=0) # [5*num_samples]
+ metrics[f"text_to_audio_mean_rank"] = pred_text_concat.mean() + 1
+ metrics[f"text_to_audio_median_rank"] = (
+ np.floor(np.median(pred_text_concat)) + 1
+ )
+ for k in [1, 5, 10]:
+ metrics[f"text_to_audio_R@{k}"] = np.mean(pred_text_concat < k)
+ # map@10
+ metrics[f"text_to_audio_mAP@10"] = np.mean(
+ np.where(pred_text_concat < 10, 1 / (pred_text_concat + 1), 0.0)
+ )
+
+ # audio to text: take the best result
+            # for audio-to-text mAP@10: keep the caption ranks < 10 and pair them with ascending positions as ground truth.
+ # see https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py#L103
+ # map@10
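+            # Illustrative example (not from the original code): if the 5 captions of one audio
+            # are retrieved at 0-indexed ranks 0, 1, 3, 12 and 25, R@k uses the best rank (0) and
+            # only the ranks < 10 contribute to mAP@10: (1/1 + 2/2 + 3/4) / 5 = 0.55.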
+ map_all = []
+ pred_audio_all = []
+ for d in range(num_samples):
+ # logits_per_audio: [num_samples, num_samples*5]
+ logit_single = logits_per_audio[d, :] # [5*num_samples]
+ # Ground-truth index: [d*5, d*5+1, d*5+2, d*5+3, d*5+4]
+ ranking = torch.argsort(
+ logit_single, descending=True
+ ) # [5*num_samples]
+ # ranking: the index of first match, second match, ...
+ ground_truth = torch.arange(d * 5, d * 5 + 5)[None]
+ all_pred = torch.where(
+ torch.stack([ranking] * 5) == ground_truth.view(-1, 1)
+ )[1]
+ min_pred = torch.min(all_pred)
+ pred_audio_all.append(min_pred.detach().cpu().numpy())
+ all_pred_filter = all_pred[all_pred < 10].detach().cpu().numpy()
+                # divide by 5 because there are 5 captions per audio, so captions ranked >= 10 count as 0.
+ map_single = (
+ np.sum(
+ (np.arange(1, len(all_pred_filter) + 1) / (all_pred_filter + 1))
+ )
+ / 5
+ )
+ map_all.append(map_single)
+ metrics[f"audio_to_text_mAP@10"] = np.mean(map_all)
+ for k in [1, 5, 10]:
+ metrics[f"audio_to_text_R@{k}"] = np.mean(np.array(pred_audio_all) < k)
+
+ val_metrics_all[n] = {n + "/" + k: v for k, v in metrics.items()}
+ return val_metrics_all
+
+
+def calculate_selection_performance_clotho_audiocaps(val_metrics_per_dataset):
+ """
+ Calculate performance for Clotho+AudioCaps for model selection.
+ """
+ selection_performance_all = []
+ for n in val_metrics_per_dataset.keys():
+ selection_performance = (
+ val_metrics_per_dataset[n][f"{n}/audio_to_text_mAP@10"]
+ + val_metrics_per_dataset[n][f"{n}/text_to_audio_mAP@10"]
+ ) / 2
+ selection_performance_all.append(selection_performance)
+ return np.mean(selection_performance_all)
+
+
+def select_top_metric_clotho_audiocaps(metrics, val_metrics_per_dataset, args):
+ # val_metrics_per_dataset: dict, key: dataset name, value: dict, key: metric name, value: metric value
+ # metrics: dict, key: metric name, value: metric value
+ # Hack: use args to save the top performance
+ if not hasattr(args, "top_selection_performance"):
+ selection_performance = calculate_selection_performance_clotho_audiocaps(
+ val_metrics_per_dataset
+ )
+ # TODO: write the if and else together
+ metric_update = {}
+ for n in val_metrics_per_dataset.keys():
+ for k in val_metrics_per_dataset[n].keys():
+ metric_update[
+ k.split("/")[0] + "-top" + "/" + k.split("/")[1]
+ ] = val_metrics_per_dataset[n][k]
+ metric_update["top_selection_performance"] = selection_performance
+ metric_update["top-selection-epoch"] = metrics["epoch"]
+ metrics.update(metric_update)
+ args.top_metric = metric_update
+ args.top_selection_performance = selection_performance
+ else:
+ selection_performance_new = calculate_selection_performance_clotho_audiocaps(
+ val_metrics_per_dataset
+ )
+ selection_performance_old = args.top_selection_performance
+ if selection_performance_new > selection_performance_old:
+ metric_update = {}
+ for n in val_metrics_per_dataset.keys():
+ for k in val_metrics_per_dataset[n].keys():
+ metric_update[
+ k.split("/")[0] + "-top" + "/" + k.split("/")[1]
+ ] = val_metrics_per_dataset[n][k]
+ metric_update["top_selection_performance"] = selection_performance_new
+ metric_update["top-selection-epoch"] = metrics["epoch"]
+ metrics.update(metric_update)
+ args.top_metric = metric_update
+ args.top_selection_performance = selection_performance_new
+ else:
+ metrics.update(args.top_metric)
+ return metrics
diff --git a/audioldm/clap/training/zero_shot.py b/audioldm/clap/training/zero_shot.py
new file mode 100755
index 0000000000000000000000000000000000000000..28b8fccc1af17fc69002857a7f529ac041c374f2
--- /dev/null
+++ b/audioldm/clap/training/zero_shot.py
@@ -0,0 +1,95 @@
+# NOTE: This script is currently not supported for CLAP.
+import logging
+from contextlib import suppress
+
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+
+from open_clip import tokenize
+from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template
+
+
+def zero_shot_classifier(model, classnames, templates, args):
+ with torch.no_grad():
+ zeroshot_weights = []
+ for classname in tqdm(classnames):
+ texts = [template(classname) for template in templates] # format with class
+ texts = tokenize(texts).to(args.device) # tokenize
+ if args.distributed and not args.horovod:
+ class_embeddings = model.module.encode_text(texts)
+ else:
+ class_embeddings = model.encode_text(texts)
+ class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
+ class_embedding /= class_embedding.norm()
+ zeroshot_weights.append(class_embedding)
+ zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device)
+ return zeroshot_weights
+
+
+def accuracy(output, target, topk=(1,)):
+ pred = output.topk(max(topk), 1, True, True)[1].t()
+ correct = pred.eq(target.view(1, -1).expand_as(pred))
+ return [
+ float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
+ for k in topk
+ ]
+
+
+def run(model, classifier, dataloader, args):
+ autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
+ with torch.no_grad():
+ top1, top5, n = 0.0, 0.0, 0.0
+ for images, target in tqdm(dataloader, unit_scale=args.batch_size):
+ images = images.to(args.device)
+ target = target.to(args.device)
+
+ with autocast():
+ # predict
+ if args.distributed and not args.horovod:
+ image_features = model.module.encode_image(images)
+ else:
+ image_features = model.encode_image(images)
+ image_features = F.normalize(image_features, dim=-1)
+ logits = 100.0 * image_features @ classifier
+
+ # measure accuracy
+ acc1, acc5 = accuracy(logits, target, topk=(1, 5))
+ top1 += acc1
+ top5 += acc5
+ n += images.size(0)
+
+ top1 = top1 / n
+ top5 = top5 / n
+ return top1, top5
+
+
+def zero_shot_eval(model, data, epoch, args):
+ if "imagenet-val" not in data and "imagenet-v2" not in data:
+ return {}
+ if args.zeroshot_frequency == 0:
+ return {}
+ if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
+ return {}
+
+ logging.info("Starting zero-shot imagenet.")
+
+ logging.info("Building zero-shot classifier")
+ classifier = zero_shot_classifier(
+ model, imagenet_classnames, openai_imagenet_template, args
+ )
+
+ logging.info("Using classifier")
+ results = {}
+ if "imagenet-val" in data:
+ top1, top5 = run(model, classifier, data["imagenet-val"].dataloader, args)
+ results["imagenet-zeroshot-val-top1"] = top1
+ results["imagenet-zeroshot-val-top5"] = top5
+ if "imagenet-v2" in data:
+ top1, top5 = run(model, classifier, data["imagenet-v2"].dataloader, args)
+ results["imagenetv2-zeroshot-val-top1"] = top1
+ results["imagenetv2-zeroshot-val-top5"] = top5
+
+ logging.info("Finished zero-shot imagenet.")
+
+ return results
diff --git a/audioldm/hifigan/__init__.py b/audioldm/hifigan/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e0ae476fe58c48e998c56234a55b871beba4042d
--- /dev/null
+++ b/audioldm/hifigan/__init__.py
@@ -0,0 +1,7 @@
+from .models import Generator
+
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
diff --git a/audioldm/hifigan/models.py b/audioldm/hifigan/models.py
new file mode 100755
index 0000000000000000000000000000000000000000..c4382cc39de0463f9b7c0f33f037dbc233e7cb36
--- /dev/null
+++ b/audioldm/hifigan/models.py
@@ -0,0 +1,174 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import weight_norm, remove_weight_norm
+
+LRELU_SLOPE = 0.1
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
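+    # "Same" padding for a dilated convolution: dilation * (kernel_size - 1) / 2.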
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+class ResBlock(torch.nn.Module):
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super(ResBlock, self).__init__()
+ self.h = h
+ self.convs1 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList(
+ [
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ weight_norm(
+ Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+ self.convs2.apply(init_weights)
+
+ def forward(self, x):
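+        # Each residual unit: LeakyReLU -> dilated conv (convs1) -> LeakyReLU -> plain conv (convs2), plus skip connection.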
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c1(xt)
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+
+class Generator(torch.nn.Module):
+ def __init__(self, h):
+ super(Generator, self).__init__()
+ self.h = h
+ self.num_kernels = len(h.resblock_kernel_sizes)
+ self.num_upsamples = len(h.upsample_rates)
+ self.conv_pre = weight_norm(
+ Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
+ )
+ resblock = ResBlock
+
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+ self.ups.append(
+ weight_norm(
+ ConvTranspose1d(
+ h.upsample_initial_channel // (2**i),
+ h.upsample_initial_channel // (2 ** (i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
+ for j, (k, d) in enumerate(
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
+ ):
+ self.resblocks.append(resblock(h, ch, k, d))
+
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def forward(self, x):
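+        # x: [B, num_mels, frames] mel spectrogram -> [B, 1, samples] waveform in [-1, 1],
+        # upsampled by roughly prod(upsample_rates) = 160x.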
+ x = self.conv_pre(x)
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ x = self.ups[i](x)
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+ x = xs / self.num_kernels
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ # print("Removing weight norm...")
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
diff --git a/audioldm/hifigan/utilities.py b/audioldm/hifigan/utilities.py
new file mode 100755
index 0000000000000000000000000000000000000000..47fd39ea0af181772d640feec2413cf631a75702
--- /dev/null
+++ b/audioldm/hifigan/utilities.py
@@ -0,0 +1,85 @@
+import os
+import json
+
+import torch
+import numpy as np
+
+import audioldm.hifigan as hifigan
+
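+# Fixed HiFi-GAN configuration for 16 kHz audio with 64 mel bins. The upsample rates
+# multiply to 5 * 4 * 2 * 2 * 2 = 160 = hop_size, i.e. one mel frame (10 ms) per 160 samples.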
+HIFIGAN_16K_64 = {
+ "resblock": "1",
+ "num_gpus": 6,
+ "batch_size": 16,
+ "learning_rate": 0.0002,
+ "adam_b1": 0.8,
+ "adam_b2": 0.99,
+ "lr_decay": 0.999,
+ "seed": 1234,
+ "upsample_rates": [5, 4, 2, 2, 2],
+ "upsample_kernel_sizes": [16, 16, 8, 4, 4],
+ "upsample_initial_channel": 1024,
+ "resblock_kernel_sizes": [3, 7, 11],
+ "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+ "segment_size": 8192,
+ "num_mels": 64,
+ "num_freq": 1025,
+ "n_fft": 1024,
+ "hop_size": 160,
+ "win_size": 1024,
+ "sampling_rate": 16000,
+ "fmin": 0,
+ "fmax": 8000,
+ "fmax_for_loss": None,
+ "num_workers": 4,
+ "dist_config": {
+ "dist_backend": "nccl",
+ "dist_url": "tcp://localhost:54321",
+ "world_size": 1,
+ },
+}
+
+
+def get_available_checkpoint_keys(model, ckpt):
+    print("==> Attempting to reload from %s" % ckpt)
+ state_dict = torch.load(ckpt)["state_dict"]
+ current_state_dict = model.state_dict()
+ new_state_dict = {}
+ for k in state_dict.keys():
+ if (
+ k in current_state_dict.keys()
+ and current_state_dict[k].size() == state_dict[k].size()
+ ):
+ new_state_dict[k] = state_dict[k]
+ else:
+ print("==> WARNING: Skipping %s" % k)
+ print(
+ "%s out of %s keys are matched"
+ % (len(new_state_dict.keys()), len(state_dict.keys()))
+ )
+ return new_state_dict
+
+
+def get_param_num(model):
+ num_param = sum(param.numel() for param in model.parameters())
+ return num_param
+
+
+def get_vocoder(config, device):
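+    # NOTE: the passed-in `config` is ignored; the bundled 16 kHz / 64-mel config above is always used.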
+ config = hifigan.AttrDict(HIFIGAN_16K_64)
+ vocoder = hifigan.Generator(config)
+ vocoder.eval()
+ vocoder.remove_weight_norm()
+ vocoder.to(device)
+ return vocoder
+
+
+def vocoder_infer(mels, vocoder, lengths=None):
+ with torch.no_grad():
+ wavs = vocoder(mels).squeeze(1)
+
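+    # Scale the float waveform (in [-1, 1] after the generator's tanh) to 16-bit PCM.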
+ wavs = (wavs.cpu().numpy() * 32768).astype("int16")
+
+ if lengths is not None:
+ wavs = wavs[:, :lengths]
+
+ return wavs
diff --git a/audioldm/latent_diffusion/__init__.py b/audioldm/latent_diffusion/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm/latent_diffusion/attention.py b/audioldm/latent_diffusion/attention.py
new file mode 100755
index 0000000000000000000000000000000000000000..27886f5ee3c7eb856100503b838399106ef00051
--- /dev/null
+++ b/audioldm/latent_diffusion/attention.py
@@ -0,0 +1,469 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn
+from einops import rearrange
+
+from audioldm.latent_diffusion.util import checkpoint
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = (
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+ if not glu
+ else GEGLU(dim, inner_dim)
+ )
+
+ self.net = nn.Sequential(
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
+ )
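+        # Linear attention: softmax over key positions, build a per-head (d x d) context = k @ v^T,
+        # then read it out with the queries -- cost grows linearly with the number of positions (h * w).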
+ k = k.softmax(dim=-1)
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
+ out = rearrange(
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
+ )
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = rearrange(q, "b c h w -> b (h w) c")
+ k = rearrange(k, "b c h w -> b c (h w)")
+ w_ = torch.einsum("bij,bjk->bik", q, k)
+
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, "b c h w -> b c (h w)")
+ w_ = rearrange(w_, "b i j -> b j i")
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class CrossAttention(nn.Module):
+ """
+ ### Cross Attention Layer
+    This falls back to self-attention when conditional embeddings are not specified.
+ """
+
+ # use_flash_attention: bool = True
+ use_flash_attention: bool = False
+
+ def __init__(
+ self,
+ query_dim,
+ context_dim=None,
+ heads=8,
+ dim_head=64,
+ dropout=0.0,
+ is_inplace: bool = True,
+ ):
+ # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
+ """
+        :param query_dim: is the input embedding size (d_model)
+        :param heads: is the number of attention heads
+        :param dim_head: is the size of an attention head
+        :param context_dim: is the size of the conditional embeddings (defaults to query_dim)
+        :param is_inplace: specifies whether to perform the attention softmax computation inplace to
+            save memory
+ """
+ super().__init__()
+
+ self.is_inplace = is_inplace
+ self.n_heads = heads
+ self.d_head = dim_head
+
+ # Attention scaling factor
+ self.scale = dim_head**-0.5
+
+ # The normal self-attention layer
+ if context_dim is None:
+ context_dim = query_dim
+
+ # Query, key and value mappings
+ d_attn = dim_head * heads
+ self.to_q = nn.Linear(query_dim, d_attn, bias=False)
+ self.to_k = nn.Linear(context_dim, d_attn, bias=False)
+ self.to_v = nn.Linear(context_dim, d_attn, bias=False)
+
+ # Final linear layer
+ self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout))
+
+ # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
+ # Flash attention is only used if it's installed
+ # and `CrossAttention.use_flash_attention` is set to `True`.
+ try:
+ # You can install flash attention by cloning their Github repo,
+ # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
+ # and then running `python setup.py install`
+ from flash_attn.flash_attention import FlashAttention
+
+ self.flash = FlashAttention()
+ # Set the scale for scaled dot-product attention.
+ self.flash.softmax_scale = self.scale
+ # Set to `None` if it's not installed
+ except ImportError:
+ self.flash = None
+
+ def forward(self, x, context=None, mask=None):
+ """
+ :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
+        :param context: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
+ """
+
+        # If `context` is `None` we perform self attention
+ has_cond = context is not None
+ if not has_cond:
+ context = x
+
+ # Get query, key and value vectors
+ q = self.to_q(x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ # Use flash attention if it's available and the head size is less than or equal to `128`
+ if (
+ CrossAttention.use_flash_attention
+ and self.flash is not None
+ and not has_cond
+ and self.d_head <= 128
+ ):
+ return self.flash_attention(q, k, v)
+ # Otherwise, fallback to normal attention
+ else:
+ return self.normal_attention(q, k, v)
+
+ def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+ """
+ #### Flash Attention
+ :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+        :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+        :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+ """
+
+ # Get batch size and number of elements along sequence axis (`width * height`)
+ batch_size, seq_len, _ = q.shape
+
+ # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
+ # shape `[batch_size, seq_len, 3, n_heads * d_head]`
+ qkv = torch.stack((q, k, v), dim=2)
+ # Split the heads
+ qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
+
+ # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
+ # fit this size.
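+        # e.g. d_head = 40 is padded to 64 (pad = 24); d_head = 64 needs no padding.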
+ if self.d_head <= 32:
+ pad = 32 - self.d_head
+ elif self.d_head <= 64:
+ pad = 64 - self.d_head
+ elif self.d_head <= 128:
+ pad = 128 - self.d_head
+ else:
+            raise ValueError(f"Head size {self.d_head} too large for Flash Attention")
+
+ # Pad the heads
+ if pad:
+ qkv = torch.cat(
+ (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
+ )
+
+ # Compute attention
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
+ # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
+ # TODO here I add the dtype changing
+ out, _ = self.flash(qkv.type(torch.float16))
+ # Truncate the extra head size
+ out = out[:, :, :, : self.d_head].float()
+ # Reshape to `[batch_size, seq_len, n_heads * d_head]`
+ out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
+
+ # Map to `[batch_size, height * width, d_model]` with a linear layer
+ return self.to_out(out)
+
+ def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
+ """
+ #### Normal Attention
+
+ :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+        :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+        :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
+ """
+
+ # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
+ q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32]
+ k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32]
+ v = v.view(*v.shape[:2], self.n_heads, -1)
+
+ # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
+ attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale
+
+ # Compute softmax
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
+ if self.is_inplace:
+ half = attn.shape[0] // 2
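+            # Softmax each half of the batch separately, overwriting `attn` in place to lower peak memory.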
+ attn[half:] = attn[half:].softmax(dim=-1)
+ attn[:half] = attn[:half].softmax(dim=-1)
+ else:
+ attn = attn.softmax(dim=-1)
+
+ # Compute attention output
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
+ # attn: [bs, 20, 64, 1]
+ # v: [bs, 1, 20, 32]
+ out = torch.einsum("bhij,bjhd->bihd", attn, v)
+ # Reshape to `[batch_size, height * width, n_heads * d_head]`
+ out = out.reshape(*out.shape[:2], -1)
+ # Map to `[batch_size, height * width, d_model]` with a linear layer
+ return self.to_out(out)
+
+
+# class CrossAttention(nn.Module):
+# def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+# super().__init__()
+# inner_dim = dim_head * heads
+# context_dim = default(context_dim, query_dim)
+
+# self.scale = dim_head ** -0.5
+# self.heads = heads
+
+# self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+# self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+# self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+# self.to_out = nn.Sequential(
+# nn.Linear(inner_dim, query_dim),
+# nn.Dropout(dropout)
+# )
+
+# def forward(self, x, context=None, mask=None):
+# h = self.heads
+
+# q = self.to_q(x)
+# context = default(context, x)
+# k = self.to_k(context)
+# v = self.to_v(context)
+
+# q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+# sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+# if exists(mask):
+# mask = rearrange(mask, 'b ... -> b (...)')
+# max_neg_value = -torch.finfo(sim.dtype).max
+# mask = repeat(mask, 'b j -> (b h) () j', h=h)
+# sim.masked_fill_(~mask, max_neg_value)
+
+# # attention, what we cannot get enough of
+# attn = sim.softmax(dim=-1)
+
+# out = einsum('b i j, b j d -> b i d', attn, v)
+# out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+# return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ ):
+ super().__init__()
+ self.attn1 = CrossAttention(
+ query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is a self-attention
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = CrossAttention(
+ query_dim=dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ ) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ if context is None:
+ return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
+ else:
+ return checkpoint(
+ self._forward, (x, context), self.parameters(), self.checkpoint
+ )
+
+ def _forward(self, x, context=None):
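+        # Pre-norm residual blocks: self-attention, cross-attention over `context`, then feed-forward.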
+ x = self.attn1(self.norm1(x)) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+    Transformer block for image-like data.
+    First, project the input (aka embedding) and reshape it to (b, t, d).
+    Then apply standard transformer blocks.
+    Finally, reshape back to an image.
+
+ def __init__(
+ self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.0,
+ context_dim=None,
+ no_context=False,
+ ):
+ super().__init__()
+
+ if no_context:
+ context_dim = None
+
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+
+ self.proj_in = nn.Conv2d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+ )
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
+ )
+ for d in range(depth)
+ ]
+ )
+
+ self.proj_out = zero_module(
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ )
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = self.proj_in(x)
+ x = rearrange(x, "b c h w -> b (h w) c")
+ for block in self.transformer_blocks:
+ x = block(x, context=context)
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+ x = self.proj_out(x)
+ return x + x_in
diff --git a/audioldm/latent_diffusion/ddim.py b/audioldm/latent_diffusion/ddim.py
new file mode 100755
index 0000000000000000000000000000000000000000..53c3490f5c894e2b80e09d50e7f0b367691ea2a9
--- /dev/null
+++ b/audioldm/latent_diffusion/ddim.py
@@ -0,0 +1,377 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from audioldm.latent_diffusion.util import (
+ make_ddim_sampling_parameters,
+ make_ddim_timesteps,
+ noise_like,
+ extract_into_tensor,
+)
+
+
+class DDIMSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
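+        # Stores sampler tensors as plain attributes (not registered nn buffers);
+        # note that the tensors are moved to CUDA, so a GPU is assumed to be available.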
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
+ ):
+ self.ddim_timesteps = make_ddim_timesteps(
+ ddim_discr_method=ddim_discretize,
+ num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
+ verbose=verbose,
+ )
+ alphas_cumprod = self.model.alphas_cumprod
+ assert (
+ alphas_cumprod.shape[0] == self.ddpm_num_timesteps
+ ), "alphas have to be defined for each timestep"
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer("betas", to_torch(self.model.betas))
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+ self.register_buffer(
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
+ )
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer(
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
+ )
+ self.register_buffer(
+ "sqrt_one_minus_alphas_cumprod",
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
+ )
+ self.register_buffer(
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
+ )
+ self.register_buffer(
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
+ )
+ self.register_buffer(
+ "sqrt_recipm1_alphas_cumprod",
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
+ )
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
+ alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,
+ verbose=verbose,
+ )
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
+ self.register_buffer("ddim_alphas", ddim_alphas)
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
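+        # DDIM noise scale for the original (full) step schedule:
+        # sigma_t = eta * sqrt((1 - abar_{t-1}) / (1 - abar_t) * (1 - abar_t / abar_{t-1}))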
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev)
+ / (1 - self.alphas_cumprod)
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
+ )
+ self.register_buffer(
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
+ )
+
+ @torch.no_grad()
+ def sample(
+ self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.0,
+ mask=None,
+ x0=None,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs,
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
+ )
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(
+ f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
+ )
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ samples, intermediates = self.ddim_sampling(
+ conditioning,
+ size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask,
+ x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling(
+ self,
+ cond,
+ shape,
+ x_T=None,
+ ddim_use_original_steps=False,
+ callback=None,
+ timesteps=None,
+ quantize_denoised=False,
+ mask=None,
+ x0=None,
+ img_callback=None,
+ log_every_t=100,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ ):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = (
+ self.ddpm_num_timesteps
+ if ddim_use_original_steps
+ else self.ddim_timesteps
+ )
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = (
+ int(
+ min(timesteps / self.ddim_timesteps.shape[0], 1)
+ * self.ddim_timesteps.shape[0]
+ )
+ - 1
+ )
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {"x_inter": [img], "pred_x0": [img]}
+ time_range = (
+ reversed(range(0, timesteps))
+ if ddim_use_original_steps
+ else np.flip(timesteps)
+ )
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ # print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ # iterator = gr.Progress().tqdm(time_range, desc="DDIM Sampler", total=total_steps)
+ iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(
+ x0, ts
+ ) # TODO deterministic forward pass?
+ img = (
+ img_orig * mask + (1.0 - mask) * img
+ ) # In the first sampling step, img is pure gaussian noise
+
+ outs = self.p_sample_ddim(
+ img,
+ cond,
+ ts,
+ index=index,
+ use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised,
+ temperature=temperature,
+ noise_dropout=noise_dropout,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ img, pred_x0 = outs
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates["x_inter"].append(img)
+ intermediates["pred_x0"].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+ # fast, but does not allow for exact reconstruction
+ # t serves as an index to gather the correct alphas
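+        # q(x_t | x_0): x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise, with abar taken
+        # from either the full DDPM schedule or the DDIM subset.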
+ if use_original_steps:
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+ else:
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+ if noise is None:
+ noise = torch.randn_like(x0)
+
+ return (
+ extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
+ )
+
+ @torch.no_grad()
+ def decode(
+ self,
+ x_latent,
+ cond,
+ t_start,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ use_original_steps=False,
+ ):
+
+ timesteps = (
+ np.arange(self.ddpm_num_timesteps)
+ if use_original_steps
+ else self.ddim_timesteps
+ )
+ timesteps = timesteps[:t_start]
+
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ # print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ # iterator = gr.Progress().tqdm(time_range, desc="Decoding image", total=total_steps)
+ iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
+ x_dec = x_latent
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full(
+ (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
+ )
+ x_dec, _ = self.p_sample_ddim(
+ x_dec,
+ cond,
+ ts,
+ index=index,
+ use_original_steps=use_original_steps,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
+ return x_dec
+
+ @torch.no_grad()
+ def p_sample_ddim(
+ self,
+ x,
+ c,
+ t,
+ index,
+ repeat_noise=False,
+ use_original_steps=False,
+ quantize_denoised=False,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ ):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+            # unconditional_guidance_scale == 1: purely the conditional prediction e_t
+            # unconditional_guidance_scale == 0: purely the unconditional prediction
+            # unconditional_guidance_scale > 1: extrapolate away from the unconditional prediction (stronger conditioning)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(
+ self.model, e_t, x, t, c, **corrector_kwargs
+ )
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = (
+ self.model.alphas_cumprod_prev
+ if use_original_steps
+ else self.ddim_alphas_prev
+ )
+ sqrt_one_minus_alphas = (
+ self.model.sqrt_one_minus_alphas_cumprod
+ if use_original_steps
+ else self.ddim_sqrt_one_minus_alphas
+ )
+ sigmas = (
+ self.model.ddim_sigmas_for_original_num_steps
+ if use_original_steps
+ else self.ddim_sigmas
+ )
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full(
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
+ )
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.0:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
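+        # DDIM update: x_{t-1} = sqrt(abar_prev) * x0_pred + dir_xt + sigma_t * z (scaled by temperature)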
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise # TODO
+ return x_prev, pred_x0
diff --git a/audioldm/latent_diffusion/ddpm.py b/audioldm/latent_diffusion/ddpm.py
new file mode 100755
index 0000000000000000000000000000000000000000..ff9e0d362744db534bd8f019fe9d647e72266649
--- /dev/null
+++ b/audioldm/latent_diffusion/ddpm.py
@@ -0,0 +1,442 @@
+"""
+wild mixture of
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
+https://github.com/CompVis/taming-transformers
+-- merci
+"""
+import sys
+import os
+
+import torch
+import torch.nn as nn
+import numpy as np
+from contextlib import contextmanager
+from functools import partial
+from tqdm import tqdm
+
+from audioldm.utils import exists, default, count_params, instantiate_from_config
+from audioldm.latent_diffusion.ema import LitEma
+from audioldm.latent_diffusion.util import (
+ make_beta_schedule,
+ extract_into_tensor,
+ noise_like,
+)
+
+import soundfile as sf
+import os
+
+
+__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DiffusionWrapper(nn.Module):
+ def __init__(self, diff_model_config, conditioning_key):
+ super().__init__()
+ self.diffusion_model = instantiate_from_config(diff_model_config)
+ self.conditioning_key = conditioning_key
+ assert self.conditioning_key in [
+ None,
+ "concat",
+ "crossattn",
+ "hybrid",
+ "adm",
+ "film",
+ ]
+
+ def forward(
+ self, x, t, c_concat: list = None, c_crossattn: list = None, c_film: list = None
+ ):
+ x = x.contiguous()
+ t = t.contiguous()
+
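+        # Dispatch on the conditioning type: "concat" concatenates the condition along channels,
+        # "crossattn" passes tokens as cross-attention context, "film" passes a single global token
+        # that the UNet projects and adds to the timestep embedding, and "adm" passes a class-style vector as `y`.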
+ if self.conditioning_key is None:
+ out = self.diffusion_model(x, t)
+ elif self.conditioning_key == "concat":
+ xc = torch.cat([x] + c_concat, dim=1)
+ out = self.diffusion_model(xc, t)
+ elif self.conditioning_key == "crossattn":
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(x, t, context=cc)
+ elif self.conditioning_key == "hybrid":
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc)
+ elif (
+ self.conditioning_key == "film"
+        ):  # The condition is assumed to be a global token, which will pass through a linear layer and be added to the time embedding (FiLM)
+ cc = c_film[0].squeeze(1) # only has one token
+ out = self.diffusion_model(x, t, y=cc)
+ elif self.conditioning_key == "adm":
+ cc = c_crossattn[0]
+ out = self.diffusion_model(x, t, y=cc)
+ else:
+ raise NotImplementedError()
+
+ return out
+
+
+class DDPM(nn.Module):
+ # classic DDPM with Gaussian diffusion, in image space
+ def __init__(
+ self,
+ unet_config,
+ timesteps=1000,
+ beta_schedule="linear",
+ loss_type="l2",
+ ckpt_path=None,
+ ignore_keys=[],
+ load_only_unet=False,
+ monitor="val/loss",
+ use_ema=True,
+ first_stage_key="image",
+ latent_t_size=256,
+ latent_f_size=16,
+ channels=3,
+ log_every_t=100,
+ clip_denoised=True,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ given_betas=None,
+ original_elbo_weight=0.0,
+ v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+ l_simple_weight=1.0,
+ conditioning_key=None,
+ parameterization="eps", # all assuming fixed variance schedules
+ scheduler_config=None,
+ use_positional_encodings=False,
+ learn_logvar=False,
+ logvar_init=0.0,
+ ):
+ super().__init__()
+ assert parameterization in [
+ "eps",
+ "x0",
+ ], 'currently only supporting "eps" and "x0"'
+ self.parameterization = parameterization
+ self.state = None
+ # print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+ self.cond_stage_model = None
+ self.clip_denoised = clip_denoised
+ self.log_every_t = log_every_t
+ self.first_stage_key = first_stage_key
+
+ self.latent_t_size = latent_t_size
+ self.latent_f_size = latent_f_size
+
+ self.channels = channels
+ self.use_positional_encodings = use_positional_encodings
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
+ count_params(self.model, verbose=True)
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model)
+ # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.use_scheduler = scheduler_config is not None
+ if self.use_scheduler:
+ self.scheduler_config = scheduler_config
+
+ self.v_posterior = v_posterior
+ self.original_elbo_weight = original_elbo_weight
+ self.l_simple_weight = l_simple_weight
+
+ if monitor is not None:
+ self.monitor = monitor
+
+ self.register_schedule(
+ given_betas=given_betas,
+ beta_schedule=beta_schedule,
+ timesteps=timesteps,
+ linear_start=linear_start,
+ linear_end=linear_end,
+ cosine_s=cosine_s,
+ )
+
+ self.loss_type = loss_type
+
+ self.learn_logvar = learn_logvar
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+ if self.learn_logvar:
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+ else:
+ self.logvar = nn.Parameter(self.logvar, requires_grad=False)
+
+ self.logger_save_dir = None
+ self.logger_project = None
+ self.logger_version = None
+ self.label_indices_total = None
+        # Buffer default metric values so checkpoint saving never fails to find a monitored metric
+ self.metrics_buffer = {
+ "val/kullback_leibler_divergence_sigmoid": 15.0,
+ "val/kullback_leibler_divergence_softmax": 10.0,
+ "val/psnr": 0.0,
+ "val/ssim": 0.0,
+ "val/inception_score_mean": 1.0,
+ "val/inception_score_std": 0.0,
+ "val/kernel_inception_distance_mean": 0.0,
+ "val/kernel_inception_distance_std": 0.0,
+ "val/frechet_inception_distance": 133.0,
+ "val/frechet_audio_distance": 32.0,
+ }
+ self.initial_learning_rate = None
+
+ def get_log_dir(self):
+ if (
+ self.logger_save_dir is None
+ and self.logger_project is None
+ and self.logger_version is None
+ ):
+ return os.path.join(
+ self.logger.save_dir, self.logger._project, self.logger.version
+ )
+ else:
+ return os.path.join(
+ self.logger_save_dir, self.logger_project, self.logger_version
+ )
+
+ def set_log_dir(self, save_dir, project, version):
+ self.logger_save_dir = save_dir
+ self.logger_project = project
+ self.logger_version = version
+
+ def register_schedule(
+ self,
+ given_betas=None,
+ beta_schedule="linear",
+ timesteps=1000,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ ):
+ if exists(given_betas):
+ betas = given_betas
+ else:
+ betas = make_beta_schedule(
+ beta_schedule,
+ timesteps,
+ linear_start=linear_start,
+ linear_end=linear_end,
+ cosine_s=cosine_s,
+ )
+ alphas = 1.0 - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+
+ (timesteps,) = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert (
+ alphas_cumprod.shape[0] == self.num_timesteps
+ ), "alphas have to be defined for each timestep"
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer("betas", to_torch(betas))
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+ self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer(
+ "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
+ )
+ self.register_buffer(
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
+ )
+ self.register_buffer(
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
+ )
+ self.register_buffer(
+ "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
+ )
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
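+        # posterior variance: beta_tilde_t = beta_t * (1 - abar_{t-1}) / (1 - abar_t),
+        # blended with beta_t via v_posterior (v = 0 recovers the standard DDPM posterior).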
+ posterior_variance = (1 - self.v_posterior) * betas * (
+ 1.0 - alphas_cumprod_prev
+ ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+ self.register_buffer("posterior_variance", to_torch(posterior_variance))
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.register_buffer(
+ "posterior_log_variance_clipped",
+ to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
+ )
+ self.register_buffer(
+ "posterior_mean_coef1",
+ to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
+ )
+ self.register_buffer(
+ "posterior_mean_coef2",
+ to_torch(
+ (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
+ ),
+ )
+
+ if self.parameterization == "eps":
+ lvlb_weights = self.betas**2 / (
+ 2
+ * self.posterior_variance
+ * to_torch(alphas)
+ * (1 - self.alphas_cumprod)
+ )
+ elif self.parameterization == "x0":
+ lvlb_weights = (
+ 0.5
+ * np.sqrt(torch.Tensor(alphas_cumprod))
+ / (2.0 * 1 - torch.Tensor(alphas_cumprod))
+ )
+ else:
+ raise NotImplementedError("mu not supported")
+ # TODO how to choose this term
+ lvlb_weights[0] = lvlb_weights[1]
+ self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
+        assert not torch.isnan(self.lvlb_weights).any()
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ # print(f"{context}: Switched to EMA weights")
+ pass
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ # print(f"{context}: Restored training weights")
+ pass
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = extract_into_tensor(
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
+ )
+ return mean, variance, log_variance
+
+ def predict_start_from_noise(self, x_t, t, noise):
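+        # Invert the forward process: x0_hat = sqrt(1 / abar_t) * x_t - sqrt(1 / abar_t - 1) * noise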
+ return (
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+ * noise
+ )
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, x, t, clip_denoised: bool):
+ model_out = self.model(x, t)
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ if clip_denoised:
+ x_recon.clamp_(-1.0, 1.0)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
+ x_start=x_recon, x_t=x, t=t
+ )
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = self.p_mean_variance(
+ x=x, t=t, clip_denoised=clip_denoised
+ )
+ noise = noise_like(x.shape, device, repeat_noise)
+ # no noise when t == 0
+ nonzero_mask = (
+ (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous()
+ )
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def p_sample_loop(self, shape, return_intermediates=False):
+ device = self.betas.device
+ b = shape[0]
+ img = torch.randn(shape, device=device)
+ intermediates = [img]
+ for i in tqdm(
+ reversed(range(0, self.num_timesteps)),
+ desc="Sampling t",
+ total=self.num_timesteps,
+ ):
+ img = self.p_sample(
+ img,
+ torch.full((b,), i, device=device, dtype=torch.long),
+ clip_denoised=self.clip_denoised,
+ )
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+ intermediates.append(img)
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, batch_size=16, return_intermediates=False):
+        channels = self.channels
+        shape = (batch_size, channels, self.latent_t_size, self.latent_f_size)
+ return self.p_sample_loop(shape, return_intermediates=return_intermediates)
+
+ def q_sample(self, x_start, t, noise=None):
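+        # Forward diffusion: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise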
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+ * noise
+ )
+
+ def forward(self, x, *args, **kwargs):
+ t = torch.randint(
+ 0, self.num_timesteps, (x.shape[0],), device=self.device
+ ).long()
+ return self.p_losses(x, t, *args, **kwargs)
+
+ def get_input(self, batch, k):
+ # fbank, log_magnitudes_stft, label_indices, fname, waveform, clip_label, text = batch
+ fbank, log_magnitudes_stft, label_indices, fname, waveform, text = batch
+ ret = {}
+
+ ret["fbank"] = (
+ fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float()
+ )
+ ret["stft"] = log_magnitudes_stft.to(
+ memory_format=torch.contiguous_format
+ ).float()
+ # ret["clip_label"] = clip_label.to(memory_format=torch.contiguous_format).float()
+ ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float()
+ ret["text"] = list(text)
+ ret["fname"] = fname
+
+ return ret[k]
diff --git a/audioldm/latent_diffusion/ema.py b/audioldm/latent_diffusion/ema.py
new file mode 100755
index 0000000000000000000000000000000000000000..880ca3d205d9b4d7450e146930a93f2e63c58b70
--- /dev/null
+++ b/audioldm/latent_diffusion/ema.py
@@ -0,0 +1,82 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError("Decay must be between 0 and 1")
+
+ self.m_name2s_name = {}
+ self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer(
+ "num_updates",
+ torch.tensor(0, dtype=torch.int)
+ if use_num_upates
+ else torch.tensor(-1, dtype=torch.int),
+ )
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+                # remove '.' since it is not allowed in buffer names
+ s_name = name.replace(".", "")
+ self.m_name2s_name.update({name: s_name})
+ self.register_buffer(s_name, p.clone().detach().data)
+
+ self.collected_params = []
+
+ def forward(self, model):
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(
+ one_minus_decay * (shadow_params[sname] - m_param[key])
+ )
+ else:
+                    assert key not in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+                assert key not in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
diff --git a/audioldm/latent_diffusion/openaimodel.py b/audioldm/latent_diffusion/openaimodel.py
new file mode 100755
index 0000000000000000000000000000000000000000..831d7aafb36bba16888e4389153979a6c13639f5
--- /dev/null
+++ b/audioldm/latent_diffusion/openaimodel.py
@@ -0,0 +1,1069 @@
+from abc import abstractmethod
+import math
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from audioldm.latent_diffusion.util import (
+ checkpoint,
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+from audioldm.latent_diffusion.attention import SpatialTransformer
+
+
+# no-op placeholders replacing the original fp16/fp32 module converters
+def convert_module_to_f16(x):
+ pass
+
+
+def convert_module_to_f32(x):
+ pass
+
+
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
+ )
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1).contiguous() # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(
+ dims, self.channels, self.out_channels, 3, padding=padding
+ )
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class TransposedUpsample(nn.Module):
+ "Learned 2x upsampling without padding"
+
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(
+ self.channels, self.out_channels, kernel_size=ks, stride=2
+ )
+
+ def forward(self, x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ stride=stride,
+ padding=padding,
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(
+ self._forward, (x,), self.parameters(), True
+        )  # TODO: verify whether checkpointing should always be enabled here, and fix the .half() call
+ # return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1).contiguous()
+ qkv = self.qkv(self.norm(x)).contiguous()
+ h = self.attention(qkv).contiguous()
+ h = self.proj_out(h).contiguous()
+ return (x + h).reshape(b, c, *spatial).contiguous()
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial**2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = (
+ qkv.reshape(bs * self.n_heads, ch * 3, length).contiguous().split(ch, dim=1)
+ )
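+        # the 1/sqrt(ch) logit scaling is split evenly across q and k (1/ch**0.25 each)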
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length).contiguous()
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum(
+ "bts,bcs->bct",
+ weight,
+ v.reshape(bs * self.n_heads, ch, length).contiguous(),
+ )
+ return a.reshape(bs, -1, length).contiguous()
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+    :param num_head_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ extra_film_condition_dim=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+        extra_film_use_concat=False,  # If True, concatenate the extra FiLM condition with the time embedding; otherwise add it
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+        n_embed=None,  # custom support for predicting discrete ids into the codebook of the first-stage VQ model
+ legacy=True,
+ ):
+ super().__init__()
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert (
+ num_head_channels != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ if num_head_channels == -1:
+ assert (
+ num_heads != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.extra_film_condition_dim = extra_film_condition_dim
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+ self.extra_film_use_concat = extra_film_use_concat
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ assert not (
+ self.num_classes is not None and self.extra_film_condition_dim is not None
+ ), "As for the condition of theh UNet model, you can only set using class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim."
+
+ if self.num_classes is not None:
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+
+ self.use_extra_film_by_concat = (
+ self.extra_film_condition_dim is not None and self.extra_film_use_concat
+ )
+ self.use_extra_film_by_addition = (
+ self.extra_film_condition_dim is not None and not self.extra_film_use_concat
+ )
+
+ if self.extra_film_condition_dim is not None:
+ self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim)
+ # print("+ Use extra condition on UNet channel using Film. Extra condition dimension is %s. " % self.extra_film_condition_dim)
+ # if(self.use_extra_film_by_concat):
+ # print("\t By concatenation with time embedding")
+ # elif(self.use_extra_film_by_concat):
+ # print("\t By addition with time embedding")
+
+ if use_spatial_transformer and (
+ self.use_extra_film_by_concat or self.use_extra_film_by_addition
+ ):
+ # print("+ Spatial transformer will only be used as self-attention. Because you have choose to use film as your global condition.")
+ spatial_transformer_no_context = True
+ else:
+ spatial_transformer_no_context = False
+
+ if use_spatial_transformer and not spatial_transformer_no_context:
+ assert (
+ context_dim is not None
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
+
+ if context_dim is not None and not spatial_transformer_no_context:
+ assert (
+ use_spatial_transformer
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
+ from omegaconf.listconfig import ListConfig
+
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim
+ if (not self.use_extra_film_by_concat)
+ else time_embed_dim * 2,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ dim_head = (
+ ch // num_heads
+ if use_spatial_transformer
+ else num_head_channels
+ )
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ if not use_spatial_transformer
+ else SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ no_context=spatial_transformer_no_context,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim
+ if (not self.use_extra_film_by_concat)
+ else time_embed_dim * 2,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim
+ if (not self.use_extra_film_by_concat)
+ else time_embed_dim * 2,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ if not use_spatial_transformer
+ else SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ no_context=spatial_transformer_no_context,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim
+ if (not self.use_extra_film_by_concat)
+ else time_embed_dim * 2,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim
+ if (not self.use_extra_film_by_concat)
+ else time_embed_dim * 2,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ # num_heads = 1
+ dim_head = (
+ ch // num_heads
+ if use_spatial_transformer
+ else num_head_channels
+ )
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ )
+ if not use_spatial_transformer
+ else SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ no_context=spatial_transformer_no_context,
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim
+ if (not self.use_extra_film_by_concat)
+ else time_embed_dim * 2,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ self.shape_reported = False
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+        :param y: an [N] Tensor of labels if class-conditional, or an [N x extra_film_condition_dim] Tensor if FiLM-embedding conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ if not self.shape_reported:
+ # print("The shape of UNet input is", x.size())
+ self.shape_reported = True
+
+ assert (y is not None) == (
+ self.num_classes is not None or self.extra_film_condition_dim is not None
+ ), "must specify y if and only if the model is class-conditional or film embedding conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape == (x.shape[0],)
+ emb = emb + self.label_emb(y)
+
+ if self.use_extra_film_by_addition:
+ emb = emb + self.film_emb(y)
+ elif self.use_extra_film_by_concat:
+ emb = th.cat([emb, self.film_emb(y)], dim=-1)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
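+
+    # Illustrative call sketch (parameter values and shapes are assumptions for a
+    # mel-spectrogram latent, not taken from a config in this patch):
+    #
+    #     unet = UNetModel(image_size=64, in_channels=8, model_channels=128,
+    #                      out_channels=8, num_res_blocks=2, attention_resolutions=(4, 8),
+    #                      num_head_channels=32, extra_film_condition_dim=512,
+    #                      extra_film_use_concat=True)
+    #     eps = unet(x, timesteps=t, y=film_emb)   # x: [N, 8, T, F], y: [N, 512]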
+
+
+class EncoderUNetModel(nn.Module):
+ """
+ The half UNet model with attention and timestep embedding.
+    For usage, see UNetModel above.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ pool="adaptive",
+ *args,
+ **kwargs,
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ use_new_attention_order=use_new_attention_order,
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+ self.pool = pool
+ if pool == "adaptive":
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ nn.AdaptiveAvgPool2d((1, 1)),
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
+ nn.Flatten(),
+ )
+ elif pool == "attention":
+ assert num_head_channels != -1
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ AttentionPool2d(
+ (image_size // ds), ch, num_head_channels, out_channels
+ ),
+ )
+ elif pool == "spatial":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ nn.ReLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ elif pool == "spatial_v2":
+ self.out = nn.Sequential(
+ nn.Linear(self._feature_size, 2048),
+ normalization(2048),
+ nn.SiLU(),
+ nn.Linear(2048, self.out_channels),
+ )
+ else:
+ raise NotImplementedError(f"Unexpected {pool} pooling")
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :return: an [N x K] Tensor of outputs.
+ """
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+
+ results = []
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+ h = self.middle_block(h, emb)
+ if self.pool.startswith("spatial"):
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
+            h = th.cat(results, dim=-1)
+ return self.out(h)
+ else:
+ h = h.type(x.dtype)
+ return self.out(h)
diff --git a/audioldm/latent_diffusion/util.py b/audioldm/latent_diffusion/util.py
new file mode 100755
index 0000000000000000000000000000000000000000..8b289f6aa7f22a070870d8a706f944dc8547e936
--- /dev/null
+++ b/audioldm/latent_diffusion/util.py
@@ -0,0 +1,295 @@
+# adapted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+from audioldm.utils import instantiate_from_config
+
+
+def make_beta_schedule(
+ schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
+):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(
+ linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
+ )
+ ** 2
+ )
+
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+        betas = torch.clamp(betas, min=0, max=0.999)
+
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(
+ linear_start, linear_end, n_timestep, dtype=torch.float64
+ )
+ elif schedule == "sqrt":
+ betas = (
+ torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ ** 0.5
+ )
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
+
+
+def make_ddim_timesteps(
+ ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
+):
+ if ddim_discr_method == "uniform":
+ c = num_ddpm_timesteps // num_ddim_timesteps
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+ elif ddim_discr_method == "quad":
+ ddim_timesteps = (
+ (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
+ ).astype(int)
+ else:
+ raise NotImplementedError(
+ f'There is no ddim discretization method called "{ddim_discr_method}"'
+ )
+
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
+ steps_out = ddim_timesteps + 1
+ if verbose:
+ print(f"Selected timesteps for ddim sampler: {steps_out}")
+ return steps_out
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+ # select alphas for computing the variance schedule
+ alphas = alphacums[ddim_timesteps]
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+    # according to the formula provided in https://arxiv.org/abs/2010.02502
+ sigmas = eta * np.sqrt(
+ (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
+ )
+ if verbose:
+ print(
+ f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
+ )
+ print(
+ f"For the chosen value of eta, which is {eta}, "
+ f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
+ )
+ return sigmas, alphas, alphas_prev
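+
+
+# Illustrative sketch of how the schedule helpers above are typically chained
+# (the concrete values here are assumptions, not taken from a config in this patch):
+#
+#     betas = make_beta_schedule("linear", n_timestep=1000)
+#     alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
+#     ddim_steps = make_ddim_timesteps("uniform", 200, 1000, verbose=False)
+#     sigmas, a_t, a_prev = make_ddim_sampling_parameters(
+#         alphas_cumprod, ddim_steps, eta=0.0, verbose=False
+#     )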
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t).contiguous()
+ return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous()
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
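+
+
+# Typical call pattern, as used by the UNet blocks in this repo: wrap a private
+# _forward so activations are recomputed during the backward pass whenever the
+# flag is set, e.g.
+#
+#     return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)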
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32)
+ / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+ )
+ else:
+ embedding = repeat(timesteps, "b -> b d", d=dim)
+ return embedding
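+
+
+# For reference: with half = dim // 2 and f_i = max_period ** (-i / half), the result is
+# [cos(t * f_0), ..., cos(t * f_{half-1}), sin(t * f_0), ..., sin(t * f_{half-1})],
+# i.e. cosines fill the first half of the channels and sines the second, with
+# frequencies log-spaced from 1 down to roughly 1 / max_period.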
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+ def __init__(self, c_concat_config, c_crossattn_config):
+ super().__init__()
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+ def forward(self, c_concat, c_crossattn):
+ c_concat = self.concat_conditioner(c_concat)
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
+ return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
+ shape[0], *((1,) * (len(shape) - 1))
+ )
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
diff --git a/audioldm/ldm.py b/audioldm/ldm.py
new file mode 100755
index 0000000000000000000000000000000000000000..5ad9ef16a2704ad6660a05d33d433a1e79593866
--- /dev/null
+++ b/audioldm/ldm.py
@@ -0,0 +1,824 @@
+import os
+
+import torch
+import numpy as np
+from einops import rearrange
+from tqdm import tqdm
+
+try:
+ from audioldm.utils import default, instantiate_from_config, save_wave
+ from audioldm.latent_diffusion.ddpm import DDPM
+ from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution
+ from audioldm.latent_diffusion.util import noise_like
+ from audioldm.latent_diffusion.ddim import DDIMSampler
+except ModuleNotFoundError:
+ from .utils import default, instantiate_from_config, save_wave
+ from .latent_diffusion.ddpm import DDPM
+ from .variational_autoencoder.distributions import DiagonalGaussianDistribution
+ from .latent_diffusion.util import noise_like
+ from .latent_diffusion.ddim import DDIMSampler
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class LatentDiffusion(DDPM):
+ """main class"""
+
+ def __init__(
+ self,
+ device="cuda",
+ first_stage_config=None,
+ cond_stage_config=None,
+ num_timesteps_cond=None,
+ cond_stage_key="image",
+ cond_stage_trainable=False,
+ concat_mode=True,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ scale_factor=1.0,
+ scale_by_std=False,
+ base_learning_rate=None,
+ *args,
+ **kwargs,
+ ):
+ self.device = device
+ self.learning_rate = base_learning_rate
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ assert self.num_timesteps_cond <= kwargs["timesteps"]
+ # for backwards compatibility after implementation of DiffusionWrapper
+ if conditioning_key is None:
+ conditioning_key = "concat" if concat_mode else "crossattn"
+ if cond_stage_config == "__is_unconditional__":
+ conditioning_key = None
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+ self.concat_mode = concat_mode
+ self.cond_stage_trainable = cond_stage_trainable
+ self.cond_stage_key = cond_stage_key
+ self.cond_stage_key_orig = cond_stage_key
+ try:
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+        except Exception:
+ self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ else:
+ self.register_buffer("scale_factor", torch.tensor(scale_factor))
+ self.instantiate_first_stage(first_stage_config)
+ self.instantiate_cond_stage(cond_stage_config)
+ self.cond_stage_forward = cond_stage_forward
+ self.clip_denoised = False
+
+ def make_cond_schedule(
+ self,
+ ):
+ self.cond_ids = torch.full(
+ size=(self.num_timesteps,),
+ fill_value=self.num_timesteps - 1,
+ dtype=torch.long,
+ )
+ ids = torch.round(
+ torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
+ ).long()
+ self.cond_ids[: self.num_timesteps_cond] = ids
+
+ def register_schedule(
+ self,
+ given_betas=None,
+ beta_schedule="linear",
+ timesteps=1000,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ ):
+ super().register_schedule(
+ given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
+ )
+
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
+ if self.shorten_cond_schedule:
+ self.make_cond_schedule()
+
+ def instantiate_first_stage(self, config):
+ model = instantiate_from_config(config)
+ self.first_stage_model = model.eval()
+ self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ def instantiate_cond_stage(self, config):
+ if not self.cond_stage_trainable:
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__":
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
+ self.cond_stage_model = None
+ # self.be_unconditional = True
+ else:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model.eval()
+ self.cond_stage_model.train = disabled_train
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = False
+ else:
+ assert config != "__is_first_stage__"
+ assert config != "__is_unconditional__"
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model
+ self.cond_stage_model = self.cond_stage_model.to(self.device)
+
+ def get_first_stage_encoding(self, encoder_posterior):
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample()
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(
+ f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
+ )
+ return self.scale_factor * z
+
+ def get_learned_conditioning(self, c):
+ if self.cond_stage_forward is None:
+ if hasattr(self.cond_stage_model, "encode") and callable(
+ self.cond_stage_model.encode
+ ):
+ c = self.cond_stage_model.encode(c)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ else:
+ # Text input is list
+ if type(c) == list and len(c) == 1:
+ c = self.cond_stage_model([c[0], c[0]])
+ c = c[0:1]
+ else:
+ c = self.cond_stage_model(c)
+ else:
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+ return c
+
+ @torch.no_grad()
+ def get_input(
+ self,
+ batch,
+ k,
+ return_first_stage_encode=True,
+ return_first_stage_outputs=False,
+ force_c_encode=False,
+ cond_key=None,
+ return_original_cond=False,
+ bs=None,
+ ):
+ x = super().get_input(batch, k)
+
+ if bs is not None:
+ x = x[:bs]
+
+ x = x.to(self.device)
+
+ if return_first_stage_encode:
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ else:
+ z = None
+
+ if self.model.conditioning_key is not None:
+ if cond_key is None:
+ cond_key = self.cond_stage_key
+ if cond_key != self.first_stage_key:
+ if cond_key in ["caption", "coordinates_bbox"]:
+ xc = batch[cond_key]
+ elif cond_key == "class_label":
+ xc = batch
+ else:
+ # [bs, 1, 527]
+ xc = super().get_input(batch, cond_key)
+ if type(xc) == torch.Tensor:
+ xc = xc.to(self.device)
+ else:
+ xc = x
+ if not self.cond_stage_trainable or force_c_encode:
+ if isinstance(xc, dict) or isinstance(xc, list):
+ c = self.get_learned_conditioning(xc)
+ else:
+ c = self.get_learned_conditioning(xc.to(self.device))
+ else:
+ c = xc
+
+ if bs is not None:
+ c = c[:bs]
+
+ else:
+ c = None
+ xc = None
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ c = {"pos_x": pos_x, "pos_y": pos_y}
+ out = [z, c]
+ if return_first_stage_outputs:
+ xrec = self.decode_first_stage(z)
+ out.extend([x, xrec])
+ if return_original_cond:
+ out.append(xc)
+ return out
+
+ @torch.no_grad()
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, "b h w c -> b c h w").contiguous()
+
+ z = 1.0 / self.scale_factor * z
+ return self.first_stage_model.decode(z)
+
+ def mel_spectrogram_to_waveform(self, mel):
+ # Mel: [bs, 1, t-steps, fbins]
+ if len(mel.size()) == 4:
+ mel = mel.squeeze(1)
+ mel = mel.permute(0, 2, 1)
+ waveform = self.first_stage_model.vocoder(mel)
+ waveform = waveform.cpu().detach().numpy()
+ return waveform
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ return self.first_stage_model.encode(x)
+
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
+
+ if isinstance(cond, dict):
+            # hybrid case, cond is expected to be a dict
+ pass
+ else:
+ if not isinstance(cond, list):
+ cond = [cond]
+ if self.model.conditioning_key == "concat":
+ key = "c_concat"
+ elif self.model.conditioning_key == "crossattn":
+ key = "c_crossattn"
+ else:
+ key = "c_film"
+
+ cond = {key: cond}
+
+ x_recon = self.model(x_noisy, t, **cond)
+
+ if isinstance(x_recon, tuple) and not return_ids:
+ return x_recon[0]
+ else:
+ return x_recon
+
+ def p_mean_variance(
+ self,
+ x,
+ c,
+ t,
+ clip_denoised: bool,
+ return_codebook_ids=False,
+ quantize_denoised=False,
+ return_x0=False,
+ score_corrector=None,
+ corrector_kwargs=None,
+ ):
+ t_in = t
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ model_out = score_corrector.modify_score(
+ self, model_out, x, t, c, **corrector_kwargs
+ )
+
+ if return_codebook_ids:
+ model_out, logits = model_out
+
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ else:
+ raise NotImplementedError()
+
+ if clip_denoised:
+ x_recon.clamp_(-1.0, 1.0)
+ if quantize_denoised:
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
+ x_start=x_recon, x_t=x, t=t
+ )
+ if return_codebook_ids:
+ return model_mean, posterior_variance, posterior_log_variance, logits
+ elif return_x0:
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
+ else:
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(
+ self,
+ x,
+ c,
+ t,
+ clip_denoised=False,
+ repeat_noise=False,
+ return_codebook_ids=False,
+ quantize_denoised=False,
+ return_x0=False,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ ):
+ b, *_, device = *x.shape, x.device
+ outputs = self.p_mean_variance(
+ x=x,
+ c=c,
+ t=t,
+ clip_denoised=clip_denoised,
+ return_codebook_ids=return_codebook_ids,
+ quantize_denoised=quantize_denoised,
+ return_x0=return_x0,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ )
+ if return_codebook_ids:
+ raise DeprecationWarning("Support dropped.")
+ model_mean, _, model_log_variance, logits = outputs
+ elif return_x0:
+ model_mean, _, model_log_variance, x0 = outputs
+ else:
+ model_mean, _, model_log_variance = outputs
+
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.0:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ # no noise when t == 0
+ nonzero_mask = (
+ (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous()
+ )
+
+ if return_codebook_ids:
+ return model_mean + nonzero_mask * (
+ 0.5 * model_log_variance
+ ).exp() * noise, logits.argmax(dim=1)
+ if return_x0:
+ return (
+ model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
+ x0,
+ )
+ else:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def progressive_denoising(
+ self,
+ cond,
+ shape,
+ verbose=True,
+ callback=None,
+ quantize_denoised=False,
+ img_callback=None,
+ mask=None,
+ x0=None,
+ temperature=1.0,
+ noise_dropout=0.0,
+ score_corrector=None,
+ corrector_kwargs=None,
+ batch_size=None,
+ x_T=None,
+ start_T=None,
+ log_every_t=None,
+ ):
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ timesteps = self.num_timesteps
+ if batch_size is not None:
+            b = batch_size
+ shape = [batch_size] + list(shape)
+ else:
+ b = batch_size = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=self.device)
+ else:
+ img = x_T
+ intermediates = []
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {
+ key: cond[key][:batch_size]
+ if not isinstance(cond[key], list)
+ else list(map(lambda x: x[:batch_size], cond[key]))
+ for key in cond
+ }
+ else:
+ cond = (
+ [c[:batch_size] for c in cond]
+ if isinstance(cond, list)
+ else cond[:batch_size]
+ )
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = (
+ tqdm(
+ reversed(range(0, timesteps)),
+ desc="Progressive Generation",
+ total=timesteps,
+ )
+ if verbose
+ else reversed(range(0, timesteps))
+ )
+ if type(temperature) == float:
+ temperature = [temperature] * timesteps
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != "hybrid"
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img, x0_partial = self.p_sample(
+ img,
+ cond,
+ ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised,
+ return_x0=True,
+ temperature=temperature[i],
+ noise_dropout=noise_dropout,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ )
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1.0 - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(x0_partial)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_loop(
+ self,
+ cond,
+ shape,
+ return_intermediates=False,
+ x_T=None,
+ verbose=True,
+ callback=None,
+ timesteps=None,
+ quantize_denoised=False,
+ mask=None,
+ x0=None,
+ img_callback=None,
+ start_T=None,
+ log_every_t=None,
+ ):
+
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ device = self.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ intermediates = [img]
+ if timesteps is None:
+ timesteps = self.num_timesteps
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = (
+ tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps)
+ if verbose
+ else reversed(range(0, timesteps))
+ )
+
+ if mask is not None:
+ assert x0 is not None
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != "hybrid"
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img = self.p_sample(
+ img,
+ cond,
+ ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised,
+ )
+ if mask is not None:
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1.0 - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(img)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
+
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(
+ self,
+ cond,
+ batch_size=16,
+ return_intermediates=False,
+ x_T=None,
+ verbose=True,
+ timesteps=None,
+ quantize_denoised=False,
+ mask=None,
+ x0=None,
+ shape=None,
+ **kwargs,
+ ):
+ if shape is None:
+ shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size)
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {
+ key: cond[key][:batch_size]
+ if not isinstance(cond[key], list)
+ else list(map(lambda x: x[:batch_size], cond[key]))
+ for key in cond
+ }
+ else:
+ cond = (
+ [c[:batch_size] for c in cond]
+ if isinstance(cond, list)
+ else cond[:batch_size]
+ )
+ return self.p_sample_loop(
+ cond,
+ shape,
+ return_intermediates=return_intermediates,
+ x_T=x_T,
+ verbose=verbose,
+ timesteps=timesteps,
+ quantize_denoised=quantize_denoised,
+ mask=mask,
+ x0=x0,
+ **kwargs,
+ )
+
+ @torch.no_grad()
+ def sample_log(
+ self,
+ cond,
+ batch_size,
+ ddim,
+ ddim_steps,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ use_plms=False,
+ mask=None,
+ **kwargs,
+ ):
+
+ if mask is not None:
+ shape = (self.channels, mask.size()[-2], mask.size()[-1])
+ else:
+ shape = (self.channels, self.latent_t_size, self.latent_f_size)
+
+ intermediate = None
+ if ddim and not use_plms:
+ # print("Use ddim sampler")
+
+ ddim_sampler = DDIMSampler(self)
+ samples, intermediates = ddim_sampler.sample(
+ ddim_steps,
+ batch_size,
+ shape,
+ cond,
+ verbose=False,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ mask=mask,
+ **kwargs,
+ )
+
+ else:
+ # print("Use DDPM sampler")
+ samples, intermediates = self.sample(
+ cond=cond,
+ batch_size=batch_size,
+ return_intermediates=True,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ mask=mask,
+ unconditional_conditioning=unconditional_conditioning,
+ **kwargs,
+ )
+
+ return samples, intermediate
+
+ @torch.no_grad()
+ def generate_sample(
+ self,
+ batchs,
+ ddim_steps=200,
+ ddim_eta=1.0,
+ x_T=None,
+ n_candidate_gen_per_text=1,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ name="waveform",
+ use_plms=False,
+ save=False,
+ **kwargs,
+ ):
+ # Generate n_candidate_gen_per_text times and select the best
+ # Batch: audio, text, fnames
+ assert x_T is None
+ try:
+ batchs = iter(batchs)
+ except TypeError:
+ raise ValueError("The first input argument should be an iterable object")
+
+ if use_plms:
+ assert ddim_steps is not None
+ use_ddim = ddim_steps is not None
+ # waveform_save_path = os.path.join(self.get_log_dir(), name)
+ # os.makedirs(waveform_save_path, exist_ok=True)
+ # print("Waveform save path: ", waveform_save_path)
+
+ with self.ema_scope("Generate"):
+ for batch in batchs:
+ z, c = self.get_input(
+ batch,
+ self.first_stage_key,
+ cond_key=self.cond_stage_key,
+ return_first_stage_outputs=False,
+ force_c_encode=True,
+ return_original_cond=False,
+ bs=None,
+ )
+ text = super().get_input(batch, "text")
+
+ # Generate multiple samples
+ batch_size = z.shape[0] * n_candidate_gen_per_text
+ c = torch.cat([c] * n_candidate_gen_per_text, dim=0)
+ text = text * n_candidate_gen_per_text
+
+ if unconditional_guidance_scale != 1.0:
+ unconditional_conditioning = (
+ self.cond_stage_model.get_unconditional_condition(batch_size)
+ )
+
+ samples, _ = self.sample_log(
+ cond=c,
+ batch_size=batch_size,
+ x_T=x_T,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ use_plms=use_plms,
+ )
+
+ if(torch.max(torch.abs(samples)) > 1e2):
+ samples = torch.clip(samples, min=-10, max=10)
+
+ mel = self.decode_first_stage(samples)
+
+ waveform = self.mel_spectrogram_to_waveform(mel)
+
+ if waveform.shape[0] > 1:
+ similarity = self.cond_stage_model.cos_similarity(
+ torch.FloatTensor(waveform).squeeze(1), text
+ )
+
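+                    # The batch holds n_candidate_gen_per_text candidates per prompt,
+                    # laid out in consecutive blocks of z.shape[0]; keep the candidate
+                    # with the highest audio-text (CLAP) similarity for each prompt.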
+ best_index = []
+ for i in range(z.shape[0]):
+ candidates = similarity[i :: z.shape[0]]
+ max_index = torch.argmax(candidates).item()
+ best_index.append(i + max_index * z.shape[0])
+
+ waveform = waveform[best_index]
+ # print("Similarity between generated audio and text", similarity)
+ # print("Choose the following indexes:", best_index)
+
+ return waveform
+
+ @torch.no_grad()
+ def generate_sample_masked(
+ self,
+ batchs,
+ ddim_steps=200,
+ ddim_eta=1.0,
+ x_T=None,
+ n_candidate_gen_per_text=1,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ name="waveform",
+ use_plms=False,
+ time_mask_ratio_start_and_end=(0.25, 0.75),
+ freq_mask_ratio_start_and_end=(0.75, 1.0),
+ save=False,
+ **kwargs,
+ ):
+ # Generate n_candidate_gen_per_text times and select the best
+ # Batch: audio, text, fnames
+ assert x_T is None
+ try:
+ batchs = iter(batchs)
+ except TypeError:
+ raise ValueError("The first input argument should be an iterable object")
+
+ if use_plms:
+ assert ddim_steps is not None
+ use_ddim = ddim_steps is not None
+ # waveform_save_path = os.path.join(self.get_log_dir(), name)
+ # os.makedirs(waveform_save_path, exist_ok=True)
+ # print("Waveform save path: ", waveform_save_path)
+
+ with self.ema_scope("Generate"):
+ for batch in batchs:
+ z, c = self.get_input(
+ batch,
+ self.first_stage_key,
+ cond_key=self.cond_stage_key,
+ return_first_stage_outputs=False,
+ force_c_encode=True,
+ return_original_cond=False,
+ bs=None,
+ )
+ text = super().get_input(batch, "text")
+
+ # Generate multiple samples
+ batch_size = z.shape[0] * n_candidate_gen_per_text
+
+ _, h, w = z.shape[0], z.shape[2], z.shape[3]
+
+ mask = torch.ones(batch_size, h, w).to(self.device)
+
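+                # Zeros mark the region to regenerate: rows (h) index latent time frames
+                # (inpainting) and columns (w) index latent frequency bins
+                # (super-resolution). The kept region is re-imposed from the (noised)
+                # original latent at every sampling step.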
+ mask[:, int(h * time_mask_ratio_start_and_end[0]) : int(h * time_mask_ratio_start_and_end[1]), :] = 0
+ mask[:, :, int(w * freq_mask_ratio_start_and_end[0]) : int(w * freq_mask_ratio_start_and_end[1])] = 0
+ mask = mask[:, None, ...]
+
+ c = torch.cat([c] * n_candidate_gen_per_text, dim=0)
+ text = text * n_candidate_gen_per_text
+
+ if unconditional_guidance_scale != 1.0:
+ unconditional_conditioning = (
+ self.cond_stage_model.get_unconditional_condition(batch_size)
+ )
+
+ samples, _ = self.sample_log(
+ cond=c,
+ batch_size=batch_size,
+ x_T=x_T,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+                    use_plms=use_plms,
+                    mask=mask,
+                    x0=torch.cat([z] * n_candidate_gen_per_text),
+ )
+
+ mel = self.decode_first_stage(samples)
+
+ waveform = self.mel_spectrogram_to_waveform(mel)
+
+ if waveform.shape[0] > 1:
+ similarity = self.cond_stage_model.cos_similarity(
+ torch.FloatTensor(waveform).squeeze(1), text
+ )
+
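+                    # Same candidate selection as in generate_sample: keep the most
+                    # text-consistent candidate for each prompt.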
+ best_index = []
+ for i in range(z.shape[0]):
+ candidates = similarity[i :: z.shape[0]]
+ max_index = torch.argmax(candidates).item()
+ best_index.append(i + max_index * z.shape[0])
+
+ waveform = waveform[best_index]
+ # print("Similarity between generated audio and text", similarity)
+ # print("Choose the following indexes:", best_index)
+
+ return waveform
\ No newline at end of file
diff --git a/audioldm/pipeline.py b/audioldm/pipeline.py
new file mode 100755
index 0000000000000000000000000000000000000000..b08e1f77206483025ce027588c2dea1de78ae26c
--- /dev/null
+++ b/audioldm/pipeline.py
@@ -0,0 +1,301 @@
+import os
+
+import argparse
+import yaml
+import torch
+from torch import autocast
+from tqdm import tqdm, trange
+
+from audioldm import LatentDiffusion, seed_everything
+from audioldm.utils import default_audioldm_config, get_duration, get_bit_depth, get_metadata, download_checkpoint
+from audioldm.audio import wav_to_fbank, TacotronSTFT, read_wav_file
+from audioldm.latent_diffusion.ddim import DDIMSampler
+from einops import repeat
+
+def make_batch_for_text_to_audio(text, waveform=None, fbank=None, batchsize=1):
+    if batchsize < 1:
+        print("Warning: Batchsize must be at least 1. Batchsize is set to 1.")
+        batchsize = 1
+    text = [text] * batchsize
+
+ if(fbank is None):
+ fbank = torch.zeros((batchsize, 1024, 64)) # Not used, here to keep the code format
+ else:
+ fbank = torch.FloatTensor(fbank)
+ fbank = fbank.expand(batchsize, 1024, 64)
+ assert fbank.size(0) == batchsize
+
+ stft = torch.zeros((batchsize, 1024, 512)) # Not used
+
+ if(waveform is None):
+ waveform = torch.zeros((batchsize, 160000)) # Not used
+ else:
+ waveform = torch.FloatTensor(waveform)
+ waveform = waveform.expand(batchsize, -1)
+ assert waveform.size(0) == batchsize
+
+ fname = [""] * batchsize # Not used
+
+ batch = (
+ fbank,
+ stft,
+ None,
+ fname,
+ waveform,
+ text,
+ )
+ return batch
+
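+# Example (sketch): a text-only batch of size 2, where fbank/stft/waveform stay as
+# zero placeholders and only the text entries condition the model:
+#   batch = make_batch_for_text_to_audio("a dog barking", batchsize=2)
+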
+def round_up_duration(duration):
+ return int(round(duration/2.5) + 1) * 2.5
+
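+# round_up_duration rounds the duration to the nearest multiple of 2.5 seconds and
+# then adds one extra 2.5-second step, e.g. round_up_duration(7.3) == 10.0 and
+# round_up_duration(10.0) == 12.5.
+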
+def build_model(
+ ckpt_path=None,
+ config=None,
+ model_name="audioldm-s-full"
+):
+    print("Load AudioLDM: %s" % model_name)
+
+ if(ckpt_path is None):
+ ckpt_path = get_metadata()[model_name]["path"]
+
+ if(not os.path.exists(ckpt_path)):
+ download_checkpoint(model_name)
+
+ if torch.cuda.is_available():
+ device = torch.device("cuda:0")
+ else:
+ device = torch.device("cpu")
+
+ if config is not None:
+ assert type(config) is str
+ config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
+ else:
+ config = default_audioldm_config(model_name)
+
+ # Use text as condition instead of using waveform during training
+ config["model"]["params"]["device"] = device
+ config["model"]["params"]["cond_stage_key"] = "text"
+
+ # No normalization here
+ latent_diffusion = LatentDiffusion(**config["model"]["params"])
+
+ resume_from_checkpoint = ckpt_path
+
+ checkpoint = torch.load(resume_from_checkpoint, map_location=device)
+ latent_diffusion.load_state_dict(checkpoint["state_dict"])
+
+ latent_diffusion.eval()
+ latent_diffusion = latent_diffusion.to(device)
+
+ latent_diffusion.cond_stage_model.embed_mode = "text"
+ return latent_diffusion
+
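+# Typical usage (sketch): the checkpoint path is resolved via get_metadata() and
+# downloaded into the cache directory if it is missing, so
+#   latent_diffusion = build_model(model_name="audioldm-s-full")
+# returns an eval-mode model on GPU when available, otherwise on CPU.
+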
+def duration_to_latent_t_size(duration):
+ return int(duration * 25.6)
+
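+# 25.6 latent time frames correspond to one second of audio; the default 10-second
+# duration therefore maps to the latent_t_size of 256 used in the default config.
+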
+def set_cond_audio(latent_diffusion):
+ latent_diffusion.cond_stage_key = "waveform"
+ latent_diffusion.cond_stage_model.embed_mode="audio"
+ return latent_diffusion
+
+def set_cond_text(latent_diffusion):
+ latent_diffusion.cond_stage_key = "text"
+ latent_diffusion.cond_stage_model.embed_mode="text"
+ return latent_diffusion
+
+def text_to_audio(
+ latent_diffusion,
+ text,
+ original_audio_file_path = None,
+ seed=42,
+ ddim_steps=200,
+ duration=10,
+ batchsize=1,
+ guidance_scale=2.5,
+ n_candidate_gen_per_text=3,
+ config=None,
+):
+ seed_everything(int(seed))
+ waveform = None
+ if(original_audio_file_path is not None):
+ waveform = read_wav_file(original_audio_file_path, int(duration * 102.4) * 160)
+
+ batch = make_batch_for_text_to_audio(text, waveform=waveform, batchsize=batchsize)
+
+ latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
+
+ if(waveform is not None):
+        print("Generating audio with content similar to %s" % original_audio_file_path)
+ latent_diffusion = set_cond_audio(latent_diffusion)
+ else:
+ print("Generate audio using text %s" % text)
+ latent_diffusion = set_cond_text(latent_diffusion)
+
+ with torch.no_grad():
+ waveform = latent_diffusion.generate_sample(
+ [batch],
+ unconditional_guidance_scale=guidance_scale,
+ ddim_steps=ddim_steps,
+ n_candidate_gen_per_text=n_candidate_gen_per_text,
+ duration=duration,
+ )
+ return waveform
+
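+# Example (sketch, assuming a model from build_model above):
+#   wav = text_to_audio(latent_diffusion, "a hammer hitting a wooden surface", duration=5)
+# Passing original_audio_file_path switches the conditioning from text to the audio
+# waveform itself (see set_cond_audio above).
+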
+def style_transfer(
+ latent_diffusion,
+ text,
+ original_audio_file_path,
+ transfer_strength,
+ seed=42,
+ duration=10,
+ batchsize=1,
+ guidance_scale=2.5,
+ ddim_steps=200,
+ config=None,
+):
+ if torch.cuda.is_available():
+ device = torch.device("cuda:0")
+ else:
+ device = torch.device("cpu")
+
+ assert original_audio_file_path is not None, "You need to provide the original audio file path"
+
+ audio_file_duration = get_duration(original_audio_file_path)
+
+ assert get_bit_depth(original_audio_file_path) == 16, "The bit depth of the original audio file %s must be 16" % original_audio_file_path
+
+ # if(duration > 20):
+ # print("Warning: The duration of the audio file %s must be less than 20 seconds. Longer duration will result in Nan in model output (we are still debugging that); Automatically set duration to 20 seconds")
+ # duration = 20
+
+ if(duration >= audio_file_duration):
+        print("Warning: the duration you specified (%s seconds) must be equal to or smaller than the audio file duration (%s seconds)" % (duration, audio_file_duration))
+        duration = round_up_duration(audio_file_duration)
+        print("Set new duration to %s seconds" % duration)
+
+ # duration = round_up_duration(duration)
+
+ latent_diffusion = set_cond_text(latent_diffusion)
+
+ if config is not None:
+ assert type(config) is str
+ config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
+ else:
+ config = default_audioldm_config()
+
+ seed_everything(int(seed))
+ # latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
+ latent_diffusion.cond_stage_model.embed_mode = "text"
+
+ fn_STFT = TacotronSTFT(
+ config["preprocessing"]["stft"]["filter_length"],
+ config["preprocessing"]["stft"]["hop_length"],
+ config["preprocessing"]["stft"]["win_length"],
+ config["preprocessing"]["mel"]["n_mel_channels"],
+ config["preprocessing"]["audio"]["sampling_rate"],
+ config["preprocessing"]["mel"]["mel_fmin"],
+ config["preprocessing"]["mel"]["mel_fmax"],
+ )
+
+ mel, _, _ = wav_to_fbank(
+ original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT
+ )
+ mel = mel.unsqueeze(0).unsqueeze(0).to(device)
+ mel = repeat(mel, "1 ... -> b ...", b=batchsize)
+ init_latent = latent_diffusion.get_first_stage_encoding(
+ latent_diffusion.encode_first_stage(mel)
+ ) # move to latent space, encode and sample
+ if(torch.max(torch.abs(init_latent)) > 1e2):
+ init_latent = torch.clip(init_latent, min=-10, max=10)
+ sampler = DDIMSampler(latent_diffusion)
+ sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=1.0, verbose=False)
+
+ t_enc = int(transfer_strength * ddim_steps)
+ prompts = text
+
+ with torch.no_grad():
+ with autocast("cuda"):
+ with latent_diffusion.ema_scope():
+ uc = None
+ if guidance_scale != 1.0:
+ uc = latent_diffusion.cond_stage_model.get_unconditional_condition(
+ batchsize
+ )
+
+ c = latent_diffusion.get_learned_conditioning([prompts] * batchsize)
+ z_enc = sampler.stochastic_encode(
+ init_latent, torch.tensor([t_enc] * batchsize).to(device)
+ )
+ samples = sampler.decode(
+ z_enc,
+ c,
+ t_enc,
+ unconditional_guidance_scale=guidance_scale,
+ unconditional_conditioning=uc,
+ )
+                # Decoding the full latent can produce NaN in the output, so the last
+                # few latent time frames are dropped before decoding as a workaround.
+                x_samples = latent_diffusion.decode_first_stage(samples[:, :, :-3, :])
+ waveform = latent_diffusion.first_stage_model.decode_to_waveform(
+ x_samples
+ )
+
+ return waveform
+
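+# Example (sketch): transfer_strength in [0, 1] sets how many DDIM steps of noise are
+# added before re-denoising under the text prompt; higher values keep less of the
+# original audio.
+#   wav = style_transfer(latent_diffusion, "acoustic guitar", "input.wav", transfer_strength=0.4)
+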
+def super_resolution_and_inpainting(
+ latent_diffusion,
+ text,
+ original_audio_file_path = None,
+ seed=42,
+ ddim_steps=200,
+ duration=None,
+ batchsize=1,
+ guidance_scale=2.5,
+ n_candidate_gen_per_text=3,
+ time_mask_ratio_start_and_end=(0.10, 0.15), # regenerate the 10% to 15% of the time steps in the spectrogram
+ # time_mask_ratio_start_and_end=(1.0, 1.0), # no inpainting
+ # freq_mask_ratio_start_and_end=(0.75, 1.0), # regenerate the higher 75% to 100% mel bins
+ freq_mask_ratio_start_and_end=(1.0, 1.0), # no super-resolution
+ config=None,
+):
+ seed_everything(int(seed))
+ if config is not None:
+ assert type(config) is str
+ config = yaml.load(open(config, "r"), Loader=yaml.FullLoader)
+ else:
+ config = default_audioldm_config()
+ fn_STFT = TacotronSTFT(
+ config["preprocessing"]["stft"]["filter_length"],
+ config["preprocessing"]["stft"]["hop_length"],
+ config["preprocessing"]["stft"]["win_length"],
+ config["preprocessing"]["mel"]["n_mel_channels"],
+ config["preprocessing"]["audio"]["sampling_rate"],
+ config["preprocessing"]["mel"]["mel_fmin"],
+ config["preprocessing"]["mel"]["mel_fmax"],
+ )
+
+ # waveform = read_wav_file(original_audio_file_path, None)
+ mel, _, _ = wav_to_fbank(
+ original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT
+ )
+
+ batch = make_batch_for_text_to_audio(text, fbank=mel[None,...], batchsize=batchsize)
+
+ # latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)
+ latent_diffusion = set_cond_text(latent_diffusion)
+
+ with torch.no_grad():
+ waveform = latent_diffusion.generate_sample_masked(
+ [batch],
+ unconditional_guidance_scale=guidance_scale,
+ ddim_steps=ddim_steps,
+ n_candidate_gen_per_text=n_candidate_gen_per_text,
+ duration=duration,
+ time_mask_ratio_start_and_end=time_mask_ratio_start_and_end,
+ freq_mask_ratio_start_and_end=freq_mask_ratio_start_and_end
+ )
+ return waveform
\ No newline at end of file
diff --git a/audioldm/utils.py b/audioldm/utils.py
new file mode 100755
index 0000000000000000000000000000000000000000..5401b29d4366774233f1bf4a9e7fcb7ce214187e
--- /dev/null
+++ b/audioldm/utils.py
@@ -0,0 +1,281 @@
+import contextlib
+import importlib
+
+from inspect import isfunction
+import os
+import soundfile as sf
+import time
+import wave
+
+import urllib.request
+import progressbar
+
+CACHE_DIR = os.getenv(
+ "AUDIOLDM_CACHE_DIR",
+ os.path.join(os.path.expanduser("~"), ".cache/audioldm"))
+
+def get_duration(fname):
+ with contextlib.closing(wave.open(fname, 'r')) as f:
+ frames = f.getnframes()
+ rate = f.getframerate()
+ return frames / float(rate)
+
+def get_bit_depth(fname):
+ with contextlib.closing(wave.open(fname, 'r')) as f:
+ bit_depth = f.getsampwidth() * 8
+ return bit_depth
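+
+# Note: get_duration and get_bit_depth rely on the standard-library wave module and
+# therefore only support uncompressed PCM .wav files.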
+
+def get_time():
+ t = time.localtime()
+ return time.strftime("%d_%m_%Y_%H_%M_%S", t)
+
+def seed_everything(seed):
+ import random, os
+ import numpy as np
+ import torch
+
+ random.seed(seed)
+ os.environ["PYTHONHASHSEED"] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+    # benchmark=True lets cuDNN auto-tune and potentially pick different algorithms
+    # between runs, which hurts reproducibility; keep it disabled here.
+    torch.backends.cudnn.benchmark = False
+
+
+def save_wave(waveform, savepath, name="outwav"):
+ if type(name) is not list:
+ name = [name] * waveform.shape[0]
+
+ for i in range(waveform.shape[0]):
+ path = os.path.join(
+ savepath,
+ "%s_%s.wav"
+ % (
+                os.path.basename(name[i])
+                if ".wav" not in name[i]
+                else os.path.basename(name[i]).split(".")[0],
+ i,
+ ),
+ )
+ print("Save audio to %s" % path)
+ sf.write(path, waveform[i, 0], samplerate=16000)
+
+
+def exists(x):
+ return x is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
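+    # e.g. default(None, 3) == 3 and default(0, 3) == 0; a function passed as d is
+    # only called when val is None.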
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+ return total_params
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def instantiate_from_config(config):
+    if "target" not in config:
+ if config == "__is_first_stage__":
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def default_audioldm_config(model_name="audioldm-s-full"):
+ basic_config = {
+ "wave_file_save_path": "./output",
+ "id": {
+ "version": "v1",
+ "name": "default",
+ "root": "/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/AudioLDM-python/config/default/latent_diffusion.yaml",
+ },
+ "preprocessing": {
+ "audio": {"sampling_rate": 16000, "max_wav_value": 32768},
+ "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024},
+ "mel": {
+ "n_mel_channels": 64,
+ "mel_fmin": 0,
+ "mel_fmax": 8000,
+ "freqm": 0,
+ "timem": 0,
+ "blur": False,
+ "mean": -4.63,
+ "std": 2.74,
+ "target_length": 1024,
+ },
+ },
+ "model": {
+ "device": "cuda",
+            "target": "audioldm.pipeline.LatentDiffusion",
+ "params": {
+ "base_learning_rate": 5e-06,
+ "linear_start": 0.0015,
+ "linear_end": 0.0195,
+ "num_timesteps_cond": 1,
+ "log_every_t": 200,
+ "timesteps": 1000,
+ "first_stage_key": "fbank",
+ "cond_stage_key": "waveform",
+ "latent_t_size": 256,
+ "latent_f_size": 16,
+ "channels": 8,
+ "cond_stage_trainable": True,
+ "conditioning_key": "film",
+ "monitor": "val/loss_simple_ema",
+ "scale_by_std": True,
+ "unet_config": {
+ "target": "audioldm.latent_diffusion.openaimodel.UNetModel",
+ "params": {
+ "image_size": 64,
+ "extra_film_condition_dim": 512,
+ "extra_film_use_concat": True,
+ "in_channels": 8,
+ "out_channels": 8,
+ "model_channels": 128,
+ "attention_resolutions": [8, 4, 2],
+ "num_res_blocks": 2,
+ "channel_mult": [1, 2, 3, 5],
+ "num_head_channels": 32,
+ "use_spatial_transformer": True,
+ },
+ },
+ "first_stage_config": {
+ "base_learning_rate": 4.5e-05,
+ "target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL",
+ "params": {
+ "monitor": "val/rec_loss",
+ "image_key": "fbank",
+ "subband": 1,
+ "embed_dim": 8,
+ "time_shuffle": 1,
+ "ddconfig": {
+ "double_z": True,
+ "z_channels": 8,
+ "resolution": 256,
+ "downsample_time": False,
+ "in_channels": 1,
+ "out_ch": 1,
+ "ch": 128,
+ "ch_mult": [1, 2, 4],
+ "num_res_blocks": 2,
+ "attn_resolutions": [],
+ "dropout": 0.0,
+ },
+ },
+ },
+ "cond_stage_config": {
+ "target": "audioldm.clap.encoders.CLAPAudioEmbeddingClassifierFreev2",
+ "params": {
+ "key": "waveform",
+ "sampling_rate": 16000,
+ "embed_mode": "audio",
+ "unconditional_prob": 0.1,
+ },
+ },
+ },
+ },
+ }
+
+ if("-l-" in model_name):
+ basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 256
+ basic_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = 64
+ elif("-m-" in model_name):
+ basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 192
+        basic_config["model"]["params"]["cond_stage_config"]["params"]["amodel"] = "HTSAT-base"  # This model uses a larger HTSAT audio encoder
+
+ return basic_config
+
+def get_metadata():
+ return {
+ "audioldm-s-full": {
+ "path": os.path.join(
+ CACHE_DIR,
+ "audioldm-s-full.ckpt",
+ ),
+ "url": "https://zenodo.org/record/7600541/files/audioldm-s-full?download=1",
+ },
+ "audioldm-l-full": {
+ "path": os.path.join(
+ CACHE_DIR,
+ "audioldm-l-full.ckpt",
+ ),
+ "url": "https://zenodo.org/record/7698295/files/audioldm-full-l.ckpt?download=1",
+ },
+ "audioldm-s-full-v2": {
+ "path": os.path.join(
+ CACHE_DIR,
+ "audioldm-s-full-v2.ckpt",
+ ),
+ "url": "https://zenodo.org/record/7698295/files/audioldm-full-s-v2.ckpt?download=1",
+ },
+ "audioldm-m-text-ft": {
+ "path": os.path.join(
+ CACHE_DIR,
+ "audioldm-m-text-ft.ckpt",
+ ),
+ "url": "https://zenodo.org/record/7813012/files/audioldm-m-text-ft.ckpt?download=1",
+ },
+ "audioldm-s-text-ft": {
+ "path": os.path.join(
+ CACHE_DIR,
+ "audioldm-s-text-ft.ckpt",
+ ),
+ "url": "https://zenodo.org/record/7813012/files/audioldm-s-text-ft.ckpt?download=1",
+ },
+ "audioldm-m-full": {
+ "path": os.path.join(
+ CACHE_DIR,
+ "audioldm-m-full.ckpt",
+ ),
+ "url": "https://zenodo.org/record/7813012/files/audioldm-m-full.ckpt?download=1",
+ },
+ }
+
+class MyProgressBar():
+ def __init__(self):
+ self.pbar = None
+
+ def __call__(self, block_num, block_size, total_size):
+ if not self.pbar:
+ self.pbar=progressbar.ProgressBar(maxval=total_size)
+ self.pbar.start()
+
+ downloaded = block_num * block_size
+ if downloaded < total_size:
+ self.pbar.update(downloaded)
+ else:
+ self.pbar.finish()
+
+def download_checkpoint(checkpoint_name="audioldm-s-full"):
+ meta = get_metadata()
+ if(checkpoint_name not in meta.keys()):
+        raise ValueError(
+            "The model name you provided is not supported. Please use one of the following: %s"
+            % list(meta.keys())
+        )
+
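+    # Re-download when the file is missing or looks truncated (full checkpoints are
+    # larger than ~2 GB).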
+ if not os.path.exists(meta[checkpoint_name]["path"]) or os.path.getsize(meta[checkpoint_name]["path"]) < 2*10**9:
+ os.makedirs(os.path.dirname(meta[checkpoint_name]["path"]), exist_ok=True)
+ print(f"Downloading the main structure of {checkpoint_name} into {os.path.dirname(meta[checkpoint_name]['path'])}")
+
+ urllib.request.urlretrieve(meta[checkpoint_name]["url"], meta[checkpoint_name]["path"], MyProgressBar())
+ print(
+            "Weights downloaded to: {} (size: {} bytes)".format(
+ meta[checkpoint_name]["path"],
+ os.path.getsize(meta[checkpoint_name]["path"]),
+ )
+ )
+
\ No newline at end of file
diff --git a/audioldm/variational_autoencoder/__init__.py b/audioldm/variational_autoencoder/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/audioldm/variational_autoencoder/autoencoder.py b/audioldm/variational_autoencoder/autoencoder.py
new file mode 100755
index 0000000000000000000000000000000000000000..9cf409edff3d6440e076f169cac10f6d0c157a95
--- /dev/null
+++ b/audioldm/variational_autoencoder/autoencoder.py
@@ -0,0 +1,103 @@
+import torch
+import torch.nn as nn
+from audioldm.latent_diffusion.ema import *
+from audioldm.variational_autoencoder.modules import Encoder, Decoder
+from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution
+
+from audioldm.hifigan.utilities import get_vocoder, vocoder_infer
+
+
+class AutoencoderKL(nn.Module):
+ def __init__(
+ self,
+ ddconfig=None,
+ lossconfig=None,
+ image_key="fbank",
+ embed_dim=None,
+ time_shuffle=1,
+ subband=1,
+ ckpt_path=None,
+ reload_from_ckpt=None,
+ ignore_keys=[],
+ colorize_nlabels=None,
+ monitor=None,
+ base_learning_rate=1e-5,
+ ):
+ super().__init__()
+
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+
+ self.subband = int(subband)
+
+ if self.subband > 1:
+ print("Use subband decomposition %s" % self.subband)
+
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+
+ self.vocoder = get_vocoder(None, "cpu")
+ self.embed_dim = embed_dim
+
+ if monitor is not None:
+ self.monitor = monitor
+
+ self.time_shuffle = time_shuffle
+ self.reload_from_ckpt = reload_from_ckpt
+ self.reloaded = False
+        self.mean, self.std = None, None
+        self.image_key = image_key  # used by freq_split_subband / freq_merge_subband
+        self.flag_first_run = True  # log the latent size only on the first forward pass
+
+ def encode(self, x):
+ # x = self.time_shuffle_operation(x)
+ x = self.freq_split_subband(x)
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ dec = self.freq_merge_subband(dec)
+ return dec
+
+ def decode_to_waveform(self, dec):
+ dec = dec.squeeze(1).permute(0, 2, 1)
+ wav_reconstruction = vocoder_infer(dec, self.vocoder)
+ return wav_reconstruction
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+
+ if self.flag_first_run:
+ print("Latent size: ", z.size())
+ self.flag_first_run = False
+
+ dec = self.decode(z)
+
+ return dec, posterior
+
+ def freq_split_subband(self, fbank):
+ if self.subband == 1 or self.image_key != "stft":
+ return fbank
+
+ bs, ch, tstep, fbins = fbank.size()
+
+ assert fbank.size(-1) % self.subband == 0
+ assert ch == 1
+
+ return (
+ fbank.squeeze(1)
+ .reshape(bs, tstep, self.subband, fbins // self.subband)
+ .permute(0, 2, 1, 3)
+ )
+
+ def freq_merge_subband(self, subband_fbank):
+ if self.subband == 1 or self.image_key != "stft":
+ return subband_fbank
+ assert subband_fbank.size(1) == self.subband # Channel dimension
+ bs, sub_ch, tstep, fbins = subband_fbank.size()
+ return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)
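+
+
+# Rough round trip (sketch): encode() returns a DiagonalGaussianDistribution over the
+# latent; z = posterior.sample() (or .mode()), decode(z) reconstructs the mel
+# spectrogram, and decode_to_waveform() runs the HiFi-GAN vocoder on the result.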
diff --git a/audioldm/variational_autoencoder/distributions.py b/audioldm/variational_autoencoder/distributions.py
new file mode 100755
index 0000000000000000000000000000000000000000..58eb535e7769f402169ddff77ee45c96ba3650d9
--- /dev/null
+++ b/audioldm/variational_autoencoder/distributions.py
@@ -0,0 +1,102 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(
+ device=self.parameters.device
+ )
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(
+ device=self.parameters.device
+ )
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.mean(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3],
+ )
+ else:
+ return 0.5 * torch.mean(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims,
+ )
+
+ def mode(self):
+ return self.mean
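+
+    # Note: sample() draws z = mean + std * eps with eps ~ N(0, I) (the
+    # reparameterization trick), and kl(other=None) is the closed-form KL against a
+    # standard normal, averaged (rather than summed) over the non-batch dimensions.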
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
diff --git a/audioldm/variational_autoencoder/modules.py b/audioldm/variational_autoencoder/modules.py
new file mode 100755
index 0000000000000000000000000000000000000000..e48386d045c1d0e159de33db02af1035159c3447
--- /dev/null
+++ b/audioldm/variational_autoencoder/modules.py
@@ -0,0 +1,1066 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+
+from audioldm.utils import instantiate_from_config
+from audioldm.latent_diffusion.attention import LinearAttention
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+    """
+    Build sinusoidal timestep embeddings, as used in Denoising Diffusion Probabilistic
+    Models (taken from Fairseq). This matches the implementation in tensor2tensor, but
+    differs slightly from the description in Section 3.5 of "Attention Is All You Need".
+    """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
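+    # x * sigmoid(x), i.e. SiLU (equivalent to torch.nn.functional.silu)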
+ return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class UpsampleTimeStride4(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=5, stride=1, padding=2
+ )
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # Do time downsampling here
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
+ )
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class DownsampleTimeStride4(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # Do time downsampling here
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
+ )
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w).contiguous()
+ q = q.permute(0, 2, 1).contiguous() # b,hw,c
+ k = k.reshape(b, c, h * w).contiguous() # b,c,hw
+ w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w).contiguous()
+ w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(
+ v, w_
+ ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h, w).contiguous()
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+def make_attn(in_channels, attn_type="vanilla"):
+ assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
+ # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ return AttnBlock(in_channels)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+
+class Model(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ use_timestep=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch * 4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList(
+ [
+ torch.nn.Linear(self.ch, self.temb_ch),
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
+ ]
+ )
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
+ )
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ skip_in = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch * in_ch_mult[i_level]
+ block.append(
+ ResnetBlock(
+ in_channels=block_in + skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x, t=None, context=None):
+ # assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb
+ )
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ double_z=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ downsample_time_stride4_levels=[],
+ **ignore_kwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.downsample_time_stride4_levels = downsample_time_stride4_levels
+
+ if len(self.downsample_time_stride4_levels) > 0:
+ assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
+                "The level at which to perform the stride-4 time downsampling needs to be smaller than the total number of resolutions (%s)"
+ % str(self.num_resolutions)
+ )
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
+ )
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ if i_level in self.downsample_time_stride4_levels:
+ down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
+ else:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in,
+ 2 * z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ tanh_out=False,
+ use_linear_attn=False,
+ downsample_time_stride4_levels=[],
+ attn_type="vanilla",
+ **ignorekwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+ self.downsample_time_stride4_levels = downsample_time_stride4_levels
+
+ if len(self.downsample_time_stride4_levels) > 0:
+ assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
+                "The level at which to perform the stride-4 time downsampling needs to be smaller than the total number of resolutions (%s)"
+ % str(self.num_resolutions)
+ )
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ # print("Working with z of shape {} = {} dimensions.".format(
+ # self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
+ )
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ if i_level - 1 in self.downsample_time_stride4_levels:
+ up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
+ else:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, z):
+ # assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList(
+ [
+ nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(
+ in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0,
+ dropout=0.0,
+ ),
+ ResnetBlock(
+ in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0,
+ dropout=0.0,
+ ),
+ ResnetBlock(
+ in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0,
+ dropout=0.0,
+ ),
+ nn.Conv2d(2 * in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True),
+ ]
+ )
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1, 2, 3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ ch,
+ num_res_blocks,
+ resolution,
+ ch_mult=(2, 2),
+ dropout=0.0,
+ ):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(
+ in_channels, mid_channels, kernel_size=3, stride=1, padding=1
+ )
+ self.res_block1 = nn.ModuleList(
+ [
+ ResnetBlock(
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0,
+ )
+ for _ in range(depth)
+ ]
+ )
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList(
+ [
+ ResnetBlock(
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0,
+ )
+ for _ in range(depth)
+ ]
+ )
+
+ self.conv_out = nn.Conv2d(
+ mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(
+ x,
+ size=(
+ int(round(x.shape[2] * self.factor)),
+ int(round(x.shape[3] * self.factor)),
+ ),
+ )
+ x = self.attn(x).contiguous()
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ ch,
+ resolution,
+ out_ch,
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ ch_mult=(1, 2, 4, 8),
+ rescale_factor=1.0,
+ rescale_module_depth=1,
+ ):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ num_res_blocks=num_res_blocks,
+ ch=ch,
+ ch_mult=ch_mult,
+ z_channels=intermediate_chn,
+ double_z=False,
+ resolution=resolution,
+ attn_resolutions=attn_resolutions,
+ dropout=dropout,
+ resamp_with_conv=resamp_with_conv,
+ out_ch=None,
+ )
+ self.rescaler = LatentRescaler(
+ factor=rescale_factor,
+ in_channels=intermediate_chn,
+ mid_channels=intermediate_chn,
+ out_channels=out_ch,
+ depth=rescale_module_depth,
+ )
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+ def __init__(
+ self,
+ z_channels,
+ out_ch,
+ resolution,
+ num_res_blocks,
+ attn_resolutions,
+ ch,
+ ch_mult=(1, 2, 4, 8),
+ dropout=0.0,
+ resamp_with_conv=True,
+ rescale_factor=1.0,
+ rescale_module_depth=1,
+ ):
+ super().__init__()
+ tmp_chn = z_channels * ch_mult[-1]
+ self.decoder = Decoder(
+ out_ch=out_ch,
+ z_channels=tmp_chn,
+ attn_resolutions=attn_resolutions,
+ dropout=dropout,
+ resamp_with_conv=resamp_with_conv,
+ in_channels=None,
+ num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult,
+ resolution=resolution,
+ ch=ch,
+ )
+ self.rescaler = LatentRescaler(
+ factor=rescale_factor,
+ in_channels=z_channels,
+ mid_channels=tmp_chn,
+ out_channels=tmp_chn,
+ depth=rescale_module_depth,
+ )
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size // in_size)) + 1
+ factor_up = 1.0 + (out_size % in_size)
+ print(
+ f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
+ )
+ self.rescaler = LatentRescaler(
+ factor=factor_up,
+ in_channels=in_channels,
+ mid_channels=2 * in_channels,
+ out_channels=in_channels,
+ )
+ self.decoder = Decoder(
+ out_ch=out_channels,
+ resolution=out_size,
+ z_channels=in_channels,
+ num_res_blocks=2,
+ attn_resolutions=[],
+ in_channels=None,
+ ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)],
+ )
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Resize(nn.Module):
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+ print(
+                f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode"
+ )
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=4, stride=2, padding=1
+ )
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor == 1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(
+ x, mode=self.mode, align_corners=False, scale_factor=scale_factor
+ )
+ return x
+
+
+class FirstStagePostProcessor(nn.Module):
+ def __init__(
+ self,
+ ch_mult: list,
+ in_channels,
+ pretrained_model: nn.Module = None,
+ reshape=False,
+ n_channels=None,
+ dropout=0.0,
+ pretrained_config=None,
+ ):
+ super().__init__()
+ if pretrained_config is None:
+ assert (
+ pretrained_model is not None
+ ), 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.pretrained_model = pretrained_model
+ else:
+ assert (
+ pretrained_config is not None
+ ), 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.instantiate_pretrained(pretrained_config)
+
+ self.do_reshape = reshape
+
+ if n_channels is None:
+ n_channels = self.pretrained_model.encoder.ch
+
+ self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
+ self.proj = nn.Conv2d(
+ in_channels, n_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ blocks = []
+ downs = []
+ ch_in = n_channels
+ for m in ch_mult:
+ blocks.append(
+ ResnetBlock(
+ in_channels=ch_in, out_channels=m * n_channels, dropout=dropout
+ )
+ )
+ ch_in = m * n_channels
+ downs.append(Downsample(ch_in, with_conv=False))
+
+ self.model = nn.ModuleList(blocks)
+ self.downsampler = nn.ModuleList(downs)
+
+ def instantiate_pretrained(self, config):
+ model = instantiate_from_config(config)
+ self.pretrained_model = model.eval()
+ # self.pretrained_model.train = False
+ for param in self.pretrained_model.parameters():
+ param.requires_grad = False
+
+ @torch.no_grad()
+ def encode_with_pretrained(self, x):
+ c = self.pretrained_model.encode(x)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ return c
+
+ def forward(self, x):
+ z_fs = self.encode_with_pretrained(x)
+ z = self.proj_norm(z_fs)
+ z = self.proj(z)
+ z = nonlinearity(z)
+
+ for submodel, downmodel in zip(self.model, self.downsampler):
+ z = submodel(z, temb=None)
+ z = downmodel(z)
+
+ if self.do_reshape:
+ z = rearrange(z, "b c h w -> b (h w) c")
+ return z
diff --git a/inversion_utils.py b/inversion_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2606a08af913705217165b4846205696575ad18
--- /dev/null
+++ b/inversion_utils.py
@@ -0,0 +1,450 @@
+import torch
+from tqdm import tqdm
+# from torchvision import transforms as T
+from typing import List, Optional, Dict, Union
+from models import PipelineWrapper
+
+
+def mu_tilde(model, xt, x0, timestep):
+ "mu_tilde(x_t, x_0) DDPM paper eq. 7"
+ prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+ alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 \
+ else model.scheduler.final_alpha_cumprod
+ alpha_t = model.scheduler.alphas[timestep]
+ beta_t = 1 - alpha_t
+ alpha_bar = model.scheduler.alphas_cumprod[timestep]
+ return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1-alpha_bar)) * x0 + \
+ ((alpha_t**0.5 * (1-alpha_prod_t_prev)) / (1 - alpha_bar)) * xt
+
+
+def sample_xts_from_x0(model, x0, num_inference_steps=50, x_prev_mode=False):
+ """
+ Samples from P(x_1:T|x_0)
+ """
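+    # x_prev_mode=False draws each x_t independently from q(x_t | x_0);
+    # x_prev_mode=True instead chains single-step transitions q(x_t | x_{t-1}).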
+ # torch.manual_seed(43256465436)
+ alpha_bar = model.model.scheduler.alphas_cumprod
+ sqrt_one_minus_alpha_bar = (1-alpha_bar) ** 0.5
+ alphas = model.model.scheduler.alphas
+ # betas = 1 - alphas
+ variance_noise_shape = (
+ num_inference_steps + 1,
+ model.model.unet.config.in_channels,
+ # model.unet.sample_size,
+ # model.unet.sample_size)
+ x0.shape[-2],
+ x0.shape[-1])
+
+ timesteps = model.model.scheduler.timesteps.to(model.device)
+ t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
+ xts = torch.zeros(variance_noise_shape).to(x0.device)
+ xts[0] = x0
+ x_prev = x0
+ for t in reversed(timesteps):
+ # idx = t_to_idx[int(t)]
+ idx = num_inference_steps-t_to_idx[int(t)]
+ if x_prev_mode:
+ xts[idx] = x_prev * (alphas[t] ** 0.5) + torch.randn_like(x0) * ((1-alphas[t]) ** 0.5)
+ x_prev = xts[idx].clone()
+ else:
+ xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
+ # xts = torch.cat([xts, x0 ],dim = 0)
+
+ return xts
+
+
+def forward_step(model, model_output, timestep, sample):
+ next_timestep = min(model.scheduler.config.num_train_timesteps - 2,
+ timestep + model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps)
+
+ # 2. compute alphas, betas
+ alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
+ # alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] if next_ltimestep >= 0 \
+ # else self.scheduler.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+
+ # 5. TODO: simple noising implementation
+ next_sample = model.scheduler.add_noise(pred_original_sample, model_output, torch.LongTensor([next_timestep]))
+ return next_sample
+
+
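+# Summary of the function below (as implemented, not a spec): with etas > 0 it samples a
+# noisy trajectory x_1..x_T from x_0 and, at every timestep, solves for the noise map
+# z_t = (x_{t-1} - mu_t) / (eta_t * sqrt(var_t)) that makes the scheduler's reverse step
+# reproduce that exact x_{t-1}; replaying these z_t's (e.g. with a new prompt) in
+# inversion_reverse_process reconstructs or edits the signal. With etas == 0 it falls back
+# to plain DDIM inversion via forward_step.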
+def inversion_forward_process(model: PipelineWrapper,
+ x0: torch.Tensor,
+ etas: Optional[float] = None,
+ prog_bar: bool = False,
+ prompts: List[str] = [""],
+ cfg_scales: List[float] = [3.5],
+ num_inference_steps: int = 50,
+ eps: Optional[float] = None,
+ cutoff_points: Optional[List[float]] = None,
+ numerical_fix: bool = False,
+ extract_h_space: bool = False,
+ extract_skipconns: bool = False,
+ x_prev_mode: bool = False):
+ if len(prompts) > 1 and extract_h_space:
+ raise NotImplementedError("How do you split cfg_scales for hspace? TODO")
+
+ if len(prompts) > 1 or prompts[0] != "":
+ text_embeddings_hidden_states, text_embeddings_class_labels, \
+ text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
+ # text_embeddings = encode_text(model, prompt)
+
+ # # classifier free guidance
+ batch_size = len(prompts)
+ cfg_scales_tensor = torch.ones((batch_size, *x0.shape[1:]), device=model.device, dtype=x0.dtype)
+
+ # if len(prompts) > 1:
+ # if cutoff_points is None:
+ # cutoff_points = [i * 1 / batch_size for i in range(1, batch_size)]
+ # if len(cfg_scales) == 1:
+ # cfg_scales *= batch_size
+ # elif len(cfg_scales) < batch_size:
+ # raise ValueError("Not enough target CFG scales")
+
+ # cutoff_points = [int(x * cfg_scales_tensor.shape[2]) for x in cutoff_points]
+ # cutoff_points = [0, *cutoff_points, cfg_scales_tensor.shape[2]]
+
+ # for i, (start, end) in enumerate(zip(cutoff_points[:-1], cutoff_points[1:])):
+ # cfg_scales_tensor[i, :, end:] = 0
+ # cfg_scales_tensor[i, :, :start] = 0
+ # cfg_scales_tensor[i] *= cfg_scales[i]
+ # if prompts[i] == "":
+ # cfg_scales_tensor[i] = 0
+ # cfg_scales_tensor = T.functional.gaussian_blur(cfg_scales_tensor, kernel_size=15, sigma=1)
+ # else:
+ cfg_scales_tensor *= cfg_scales[0]
+
+ uncond_embedding_hidden_states, uncond_embedding_class_lables, uncond_boolean_prompt_mask = model.encode_text([""])
+ # uncond_embedding = encode_text(model, "")
+ timesteps = model.model.scheduler.timesteps.to(model.device)
+ variance_noise_shape = (
+ num_inference_steps,
+ model.model.unet.config.in_channels,
+ # model.unet.sample_size,
+ # model.unet.sample_size)
+ x0.shape[-2],
+ x0.shape[-1])
+
+ if etas is None or (type(etas) in [int, float] and etas == 0):
+ eta_is_zero = True
+ zs = None
+ else:
+ eta_is_zero = False
+ if type(etas) in [int, float]:
+ etas = [etas]*model.model.scheduler.num_inference_steps
+ xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps, x_prev_mode=x_prev_mode)
+ alpha_bar = model.model.scheduler.alphas_cumprod
+ zs = torch.zeros(size=variance_noise_shape, device=model.device)
+ hspaces = []
+ skipconns = []
+ t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
+ xt = x0
+ # op = tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps)
+ op = tqdm(timesteps) if prog_bar else timesteps
+
+ for t in op:
+ # idx = t_to_idx[int(t)]
+ idx = num_inference_steps - t_to_idx[int(t)] - 1
+ # 1. predict noise residual
+ if not eta_is_zero:
+ xt = xts[idx+1][None]
+
+ with torch.no_grad():
+ out, out_hspace, out_skipconns = model.unet_forward(xt, timestep=t,
+ encoder_hidden_states=uncond_embedding_hidden_states,
+ class_labels=uncond_embedding_class_lables,
+ encoder_attention_mask=uncond_boolean_prompt_mask)
+ # out = model.unet.forward(xt, timestep= t, encoder_hidden_states=uncond_embedding)
+ if len(prompts) > 1 or prompts[0] != "":
+ cond_out, cond_out_hspace, cond_out_skipconns = model.unet_forward(
+ xt.expand(len(prompts), -1, -1, -1), timestep=t,
+ encoder_hidden_states=text_embeddings_hidden_states,
+ class_labels=text_embeddings_class_labels,
+ encoder_attention_mask=text_embeddings_boolean_prompt_mask)
+ # cond_out = model.unet.forward(xt, timestep=t, encoder_hidden_states = text_embeddings)
+
+ if len(prompts) > 1 or prompts[0] != "":
+ # # classifier free guidance
+ noise_pred = out.sample + \
+ (cfg_scales_tensor * (cond_out.sample - out.sample.expand(batch_size, -1, -1, -1))
+ ).sum(axis=0).unsqueeze(0)
+ if extract_h_space or extract_skipconns:
+ noise_h_space = out_hspace + cfg_scales[0] * (cond_out_hspace - out_hspace)
+ if extract_skipconns:
+ noise_skipconns = {k: [out_skipconns[k][j] + cfg_scales[0] *
+ (cond_out_skipconns[k][j] - out_skipconns[k][j])
+ for j in range(len(out_skipconns[k]))]
+ for k in out_skipconns}
+ else:
+ noise_pred = out.sample
+ if extract_h_space or extract_skipconns:
+ noise_h_space = out_hspace
+ if extract_skipconns:
+ noise_skipconns = out_skipconns
+ if extract_h_space or extract_skipconns:
+ hspaces.append(noise_h_space)
+ if extract_skipconns:
+ skipconns.append(noise_skipconns)
+
+ if eta_is_zero:
+ # 2. compute more noisy image and set x_t -> x_t+1
+ xt = forward_step(model.model, noise_pred, t, xt)
+ else:
+ # xtm1 = xts[idx+1][None]
+ xtm1 = xts[idx][None]
+ # pred of x0
+ if model.model.scheduler.config.prediction_type == 'epsilon':
+ pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
+ elif model.model.scheduler.config.prediction_type == 'v_prediction':
+ pred_original_sample = (alpha_bar[t] ** 0.5) * xt - ((1 - alpha_bar[t]) ** 0.5) * noise_pred
+
+ # direction to xt
+ prev_timestep = t - model.model.scheduler.config.num_train_timesteps // \
+ model.model.scheduler.num_inference_steps
+
+ alpha_prod_t_prev = model.get_alpha_prod_t_prev(prev_timestep)
+ variance = model.get_variance(t, prev_timestep)
+
+ if model.model.scheduler.config.prediction_type == 'epsilon':
+ random_noise_pred = noise_pred
+ elif model.model.scheduler.config.prediction_type == 'v_prediction':
+ random_noise_pred = (alpha_bar[t] ** 0.5) * noise_pred + ((1 - alpha_bar[t]) ** 0.5) * xt
+
+ pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * random_noise_pred
+
+ mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
+
+ zs[idx] = z
+
+ # correction to avoid error accumulation
+ if numerical_fix:
+ xtm1 = mu_xt + (etas[idx] * variance ** 0.5)*z
+ xts[idx] = xtm1
+
+ if zs is not None:
+ # zs[-1] = torch.zeros_like(zs[-1])
+ zs[0] = torch.zeros_like(zs[0])
+ # zs_cycle[0] = torch.zeros_like(zs[0])
+
+ if extract_h_space:
+ hspaces = torch.concat(hspaces, axis=0)
+ return xt, zs, xts, hspaces
+
+ if extract_skipconns:
+ hspaces = torch.concat(hspaces, axis=0)
+ return xt, zs, xts, hspaces, skipconns
+
+ return xt, zs, xts
+
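+# Example (illustrative sketch, not from the repo): a typical inversion call, assuming
+# `pipe` is a loaded PipelineWrapper and `w0` is a VAE latent of shape (1, C, H, W):
+#   _, zs, xts = inversion_forward_process(pipe, w0, etas=1.0, prompts=["source description"],
+#                                          cfg_scales=[3.5], num_inference_steps=50,
+#                                          prog_bar=True, numerical_fix=True)
+#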
+
+def reverse_step(model, model_output, timestep, sample, eta=0, variance_noise=None):
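+ # Single reverse (DDIM-with-eta) step, following Eqs. (12) and (16) of
+ # https://arxiv.org/pdf/2010.02502.pdf, as implemented below:
+ #   x_{t-1} = sqrt(alpha_bar_{t-1}) * x0_pred
+ #             + sqrt(1 - alpha_bar_{t-1} - eta * var_t) * eps_pred
+ #             + eta * sqrt(var_t) * z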
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - model.model.scheduler.config.num_train_timesteps // \
+ model.model.scheduler.num_inference_steps
+ # 2. compute alphas, betas
+ alpha_prod_t = model.model.scheduler.alphas_cumprod[timestep]
+ alpha_prod_t_prev = model.get_alpha_prod_t_prev(prev_timestep)
+ beta_prod_t = 1 - alpha_prod_t
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ if model.model.scheduler.config.prediction_type == 'epsilon':
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif model.model.scheduler.config.prediction_type == 'v_prediction':
+ pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ # variance = self.scheduler._get_variance(timestep, prev_timestep)
+ variance = model.get_variance(timestep, prev_timestep)
+ # std_dev_t = eta * variance ** (0.5)
+ # Take care of asymmetric reverse process (asyrp)
+ if model.model.scheduler.config.prediction_type == 'epsilon':
+ model_output_direction = model_output
+ elif model.model.scheduler.config.prediction_type == 'v_prediction':
+ model_output_direction = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
+ pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+ # 8. Add noise if eta > 0
+ if eta > 0:
+ if variance_noise is None:
+ variance_noise = torch.randn(model_output.shape, device=model.device)
+ sigma_z = eta * variance ** (0.5) * variance_noise
+ prev_sample = prev_sample + sigma_z
+
+ return prev_sample
+
+
+def inversion_reverse_process(model: PipelineWrapper,
+ xT: torch.Tensor,
+ skips: torch.Tensor,
+ fix_alpha: float = 0.1,
+ etas: float = 0,
+ prompts: List[str] = [""],
+ neg_prompts: List[str] = [""],
+ cfg_scales: Optional[List[float]] = None,
+ prog_bar: bool = False,
+ zs: Optional[List[torch.Tensor]] = None,
+ # controller=None,
+ cutoff_points: Optional[List[float]] = None,
+ hspace_add: Optional[torch.Tensor] = None,
+ hspace_replace: Optional[torch.Tensor] = None,
+ skipconns_replace: Optional[Dict[int, torch.Tensor]] = None,
+ zero_out_resconns: Optional[Union[int, List]] = None,
+ asyrp: bool = False,
+ extract_h_space: bool = False,
+ extract_skipconns: bool = False):
+
+ batch_size = len(prompts)
+
+ text_embeddings_hidden_states, text_embeddings_class_labels, \
+ text_embeddings_boolean_prompt_mask = model.encode_text(prompts)
+ uncond_embedding_hidden_states, uncond_embedding_class_lables, \
+ uncond_boolean_prompt_mask = model.encode_text(neg_prompts)
+ # text_embeddings = encode_text(model, prompts)
+ # uncond_embedding = encode_text(model, [""] * batch_size)
+
+ masks = torch.ones((batch_size, *xT.shape[1:]), device=model.device, dtype=xT.dtype)
+ cfg_scales_tensor = torch.ones((batch_size, *xT.shape[1:]), device=model.device, dtype=xT.dtype)
+
+ # if batch_size > 1:
+ # if cutoff_points is None:
+ # cutoff_points = [i * 1 / batch_size for i in range(1, batch_size)]
+ # if len(cfg_scales) == 1:
+ # cfg_scales *= batch_size
+ # elif len(cfg_scales) < batch_size:
+ # raise ValueError("Not enough target CFG scales")
+
+ # cutoff_points = [int(x * cfg_scales_tensor.shape[2]) for x in cutoff_points]
+ # cutoff_points = [0, *cutoff_points, cfg_scales_tensor.shape[2]]
+
+ # for i, (start, end) in enumerate(zip(cutoff_points[:-1], cutoff_points[1:])):
+ # cfg_scales_tensor[i, :, end:] = 0
+ # cfg_scales_tensor[i, :, :start] = 0
+ # masks[i, :, end:] = 0
+ # masks[i, :, :start] = 0
+ # cfg_scales_tensor[i] *= cfg_scales[i]
+ # cfg_scales_tensor = T.functional.gaussian_blur(cfg_scales_tensor, kernel_size=15, sigma=1)
+ # masks = T.functional.gaussian_blur(masks, kernel_size=15, sigma=1)
+ # else:
+ cfg_scales_tensor *= cfg_scales[0]
+
+ if etas is None:
+ etas = 0
+ if type(etas) in [int, float]:
+ etas = [etas]*model.model.scheduler.num_inference_steps
+ assert len(etas) == model.model.scheduler.num_inference_steps
+ timesteps = model.model.scheduler.timesteps.to(model.device)
+
+ # xt = xT.expand(1, -1, -1, -1)
+ xt = xT[skips.max()].unsqueeze(0)
+ op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
+
+ t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
+ hspaces = []
+ skipconns = []
+
+ for it, t in enumerate(op):
+ # idx = t_to_idx[int(t)]
+ idx = model.model.scheduler.num_inference_steps - t_to_idx[int(t)] - \
+ (model.model.scheduler.num_inference_steps - zs.shape[0] + 1)
+ # # Unconditional embedding
+ with torch.no_grad():
+ uncond_out, out_hspace, out_skipconns = model.unet_forward(
+ xt, timestep=t,
+ encoder_hidden_states=uncond_embedding_hidden_states,
+ class_labels=uncond_embedding_class_lables,
+ encoder_attention_mask=uncond_boolean_prompt_mask,
+ mid_block_additional_residual=(None if hspace_add is None else
+ (1 / (cfg_scales[0] + 1)) *
+ (hspace_add[-zs.shape[0]:][it] if hspace_add.shape[0] > 1
+ else hspace_add)),
+ replace_h_space=(None if hspace_replace is None else
+ (hspace_replace[-zs.shape[0]:][it].unsqueeze(0) if hspace_replace.shape[0] > 1
+ else hspace_replace)),
+ zero_out_resconns=zero_out_resconns,
+ replace_skip_conns=(None if skipconns_replace is None else
+ (skipconns_replace[-zs.shape[0]:][it] if len(skipconns_replace) > 1
+ else skipconns_replace))
+ ) # encoder_hidden_states = uncond_embedding)
+
+ # # Conditional embedding
+ if prompts:
+ with torch.no_grad():
+ cond_out, cond_out_hspace, cond_out_skipconns = model.unet_forward(
+ xt.expand(batch_size, -1, -1, -1),
+ timestep=t,
+ encoder_hidden_states=text_embeddings_hidden_states,
+ class_labels=text_embeddings_class_labels,
+ encoder_attention_mask=text_embeddings_boolean_prompt_mask,
+ mid_block_additional_residual=(None if hspace_add is None else
+ (cfg_scales[0] / (cfg_scales[0] + 1)) *
+ (hspace_add[-zs.shape[0]:][it] if hspace_add.shape[0] > 1
+ else hspace_add)),
+ replace_h_space=(None if hspace_replace is None else
+ (hspace_replace[-zs.shape[0]:][it].unsqueeze(0) if hspace_replace.shape[0] > 1
+ else hspace_replace)),
+ zero_out_resconns=zero_out_resconns,
+ replace_skip_conns=(None if skipconns_replace is None else
+ (skipconns_replace[-zs.shape[0]:][it] if len(skipconns_replace) > 1
+ else skipconns_replace))
+ ) # encoder_hidden_states = text_embeddings)
+
+ z = zs[idx] if zs is not None else None
+ # print(f'idx: {idx}')
+ # print(f't: {t}')
+ z = z.unsqueeze(0)
+ # z = z.expand(batch_size, -1, -1, -1)
+ if prompts:
+ # # classifier free guidance
+ # noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
+ noise_pred = uncond_out.sample + \
+ (cfg_scales_tensor * (cond_out.sample - uncond_out.sample.expand(batch_size, -1, -1, -1))
+ ).sum(axis=0).unsqueeze(0)
+ if extract_h_space or extract_skipconns:
+ noise_h_space = out_hspace + cfg_scales[0] * (cond_out_hspace - out_hspace)
+ if extract_skipconns:
+ noise_skipconns = {k: [out_skipconns[k][j] + cfg_scales[0] *
+ (cond_out_skipconns[k][j] - out_skipconns[k][j])
+ for j in range(len(out_skipconns[k]))]
+ for k in out_skipconns}
+ else:
+ noise_pred = uncond_out.sample
+ if extract_h_space or extract_skipconns:
+ noise_h_space = out_hspace
+ if extract_skipconns:
+ noise_skipconns = out_skipconns
+
+ if extract_h_space or extract_skipconns:
+ hspaces.append(noise_h_space)
+ if extract_skipconns:
+ skipconns.append(noise_skipconns)
+
+ # 2. compute less noisy image and set x_t -> x_t-1
+ xt = reverse_step(model, noise_pred, t, xt, eta=etas[idx], variance_noise=z)
+ # if controller is not None:
+ # xt = controller.step_callback(xt)
+
+ # "fix" xt
+ apply_fix = ((skips.max() - skips) > it)
+ if apply_fix.any():
+ apply_fix = (apply_fix * fix_alpha).unsqueeze(1).unsqueeze(2).unsqueeze(3).to(xT.device)
+ xt = (masks * (xt.expand(batch_size, -1, -1, -1) * (1 - apply_fix) +
+ apply_fix * xT[skips.max() - it - 1].expand(batch_size, -1, -1, -1))
+ ).sum(axis=0).unsqueeze(0)
+
+ if extract_h_space:
+ return xt, zs, torch.concat(hspaces, axis=0)
+
+ if extract_skipconns:
+ return xt, zs, torch.concat(hspaces, axis=0), skipconns
+
+ return xt, zs
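+
+
+# Example (illustrative sketch, not from the repo): replaying the stored noise maps with a
+# target prompt. Assumes `xts` and `zs` come from inversion_forward_process and that editing
+# starts n_start steps into the trajectory (so skips.max() == zs.shape[0]):
+#   n_start = 36
+#   w0_edited, _ = inversion_reverse_process(pipe, xT=xts, skips=torch.tensor([n_start]),
+#                                            etas=1.0, prompts=["target description"],
+#                                            cfg_scales=[12.0], prog_bar=True, zs=zs[:n_start])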
diff --git a/models.py b/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..574ad9cb8f5b2c07b6edebabd1b62ac61344343c
--- /dev/null
+++ b/models.py
@@ -0,0 +1,644 @@
+import torch
+from diffusers import DDIMScheduler
+from diffusers import AudioLDM2Pipeline
+from transformers import RobertaTokenizer, RobertaTokenizerFast
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+
+class PipelineWrapper(torch.nn.Module):
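+ # Thin abstraction over a diffusers pipeline. The method bodies below are placeholders
+ # that concrete wrappers (e.g. AudioLDM2Wrapper) override with model-specific logic.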
+ def __init__(self, model_id, device, double_precision=False, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.model_id = model_id
+ self.device = device
+ self.double_precision = double_precision
+
+ def get_sigma(self, timestep) -> float:
+ sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.model.scheduler.alphas_cumprod - 1)
+ return sqrt_recipm1_alphas_cumprod[timestep]
+
+ def load_scheduler(self):
+ pass
+
+ def get_fn_STFT(self):
+ pass
+
+ def vae_encode(self, x: torch.Tensor):
+ pass
+
+ def vae_decode(self, x: torch.Tensor):
+ pass
+
+ def decode_to_mel(self, x: torch.Tensor):
+ pass
+
+ def encode_text(self, prompts: List[str]) -> Tuple:
+ pass
+
+ def get_variance(self, timestep, prev_timestep):
+ pass
+
+ def get_alpha_prod_t_prev(self, prev_timestep):
+ pass
+
+ def unet_forward(self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ replace_h_space: Optional[torch.Tensor] = None,
+ replace_skip_conns: Optional[Dict[int, torch.Tensor]] = None,
+ return_dict: bool = True,
+ zero_out_resconns: Optional[Union[int, List]] = None) -> Tuple:
+
+ # By default samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.model.unet.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ # logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
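+ # e.g. a key mask [1, 1, 0] becomes the additive bias [0.0, 0.0, -10000.0]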
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.model.unet.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.model.unet.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.model.unet.time_embedding(t_emb, timestep_cond)
+
+ if self.model.unet.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.model.unet.config.class_embed_type == "timestep":
+ class_labels = self.model.unet.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.model.unet.class_embedding(class_labels).to(dtype=sample.dtype)
+
+ if self.model.unet.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ if self.model.unet.config.addition_embed_type == "text":
+ aug_emb = self.model.unet.add_embedding(encoder_hidden_states)
+ emb = emb + aug_emb
+ elif self.model.unet.config.addition_embed_type == "text_image":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.model.unet.__class__} has the config param `addition_embed_type` set to 'text_image' "
+ f"which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+
+ image_embs = added_cond_kwargs.get("image_embeds")
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+
+ aug_emb = self.model.unet.add_embedding(text_embs, image_embs)
+ emb = emb + aug_emb
+
+ if self.model.unet.time_embed_act is not None:
+ emb = self.model.unet.time_embed_act(emb)
+
+ if self.model.unet.encoder_hid_proj is not None and self.model.unet.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.model.unet.encoder_hid_proj(encoder_hidden_states)
+ elif self.model.unet.encoder_hid_proj is not None and \
+ self.model.unet.config.encoder_hid_dim_type == "text_image_proj":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.model.unet.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' "
+ f"which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.model.unet.encoder_hid_proj(encoder_hidden_states, image_embeds)
+
+ # 2. pre-process
+ sample = self.model.unet.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.model.unet.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ if down_block_additional_residuals is not None:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.model.unet.mid_block is not None:
+ sample = self.model.unet.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+
+ # print(sample.shape)
+
+ if replace_h_space is None:
+ h_space = sample.clone()
+ else:
+ h_space = replace_h_space
+ sample = replace_h_space.clone()
+
+ if mid_block_additional_residual is not None:
+ sample = sample + mid_block_additional_residual
+
+ extracted_res_conns = {}
+ # 5. up
+ for i, upsample_block in enumerate(self.model.unet.up_blocks):
+ is_final_block = i == len(self.model.unet.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+ if replace_skip_conns is not None and replace_skip_conns.get(i):
+ res_samples = replace_skip_conns.get(i)
+
+ if zero_out_resconns is not None:
+ if (type(zero_out_resconns) is int and i >= (zero_out_resconns - 1)) or \
+ (type(zero_out_resconns) is list and i in zero_out_resconns):
+ res_samples = [torch.zeros_like(x) for x in res_samples]
+ # down_block_res_samples = [torch.zeros_like(x) for x in down_block_res_samples]
+
+ extracted_res_conns[i] = res_samples
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+
+ # 6. post-process
+ if self.model.unet.conv_norm_out:
+ sample = self.model.unet.conv_norm_out(sample)
+ sample = self.model.unet.conv_act(sample)
+ sample = self.model.unet.conv_out(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample), h_space, extracted_res_conns
+
+
+class AudioLDM2Wrapper(PipelineWrapper):
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ if self.double_precision:
+ self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, torch_dtype=torch.float64).to(self.device)
+ else:
+ try:
+ self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=True).to(self.device)
+ except OSError:  # also covers FileNotFoundError; diffusers raises OSError when the model is not cached locally
+ self.model = AudioLDM2Pipeline.from_pretrained(self.model_id, local_files_only=False).to(self.device)
+
+ def load_scheduler(self):
+ # self.model.scheduler = DDIMScheduler.from_config(self.model_id, subfolder="scheduler")
+ self.model.scheduler = DDIMScheduler.from_pretrained(self.model_id, subfolder="scheduler")
+
+ def get_fn_STFT(self):
+ from audioldm.audio import TacotronSTFT
+ return TacotronSTFT(
+ filter_length=1024,
+ hop_length=160,
+ win_length=1024,
+ n_mel_channels=64,
+ sampling_rate=16000,
+ mel_fmin=0,
+ mel_fmax=8000,
+ )
+
+ def vae_encode(self, x):
+ # self.model.vae.disable_tiling()
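+ # pad dim 2 (mel frames) up to a multiple of 4 before encoding; 4 is assumed here to be
+ # the VAE's spatial downsampling factor, so the latent time length stays an integer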
+ if x.shape[2] % 4:
+ x = torch.nn.functional.pad(x, (0, 0, 4 - (x.shape[2] % 4), 0))
+ return (self.model.vae.encode(x).latent_dist.mode() * self.model.vae.config.scaling_factor).float()
+ # return (self.encode_no_tiling(x).latent_dist.mode() * self.model.vae.config.scaling_factor).float()
+
+ def vae_decode(self, x):
+ return self.model.vae.decode(1 / self.model.vae.config.scaling_factor * x).sample
+
+ def decode_to_mel(self, x):
+ if self.double_precision:
+ tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().double()).detach()
+ else:
+ tmp = self.model.mel_spectrogram_to_waveform(x[:, 0].detach().float()).detach()
+ if len(tmp.shape) == 1:
+ tmp = tmp.unsqueeze(0)
+ return tmp
+
+ def encode_text(self, prompts: List[str]):
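+ # AudioLDM2 text conditioning (summary of the steps below): both tokenizer/encoder pairs
+ # (CLAP and the second text encoder) embed the prompts; the CLAP output is a pooled vector
+ # that is given a singleton sequence dim. The projection model fuses the two streams and the
+ # pipeline's language model generates the final prompt embeddings. Returns
+ # (generated_prompt_embeds, prompt_embeds, attention_mask), which unet_forward consumes as
+ # (encoder_hidden_states, class_labels, encoder_attention_mask).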
+ tokenizers = [self.model.tokenizer, self.model.tokenizer_2]
+ text_encoders = [self.model.text_encoder, self.model.text_encoder_2]
+ prompt_embeds_list = []
+ attention_mask_list = []
+
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ prompts,
+ padding="max_length" if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True,
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ attention_mask = text_inputs.attention_mask
+ untruncated_ids = tokenizer(prompts, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] \
+ and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = tokenizer.batch_decode(
+ untruncated_ids[:, tokenizer.model_max_length - 1: -1])
+ print(f"The following part of your input was truncated because {text_encoder.config.model_type} can "
+ f"only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ text_input_ids = text_input_ids.to(self.device)
+ attention_mask = attention_mask.to(self.device)
+
+ with torch.no_grad():
+ if text_encoder.config.model_type == "clap":
+ prompt_embeds = text_encoder.get_text_features(
+ text_input_ids,
+ attention_mask=attention_mask,
+ )
+ # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
+ prompt_embeds = prompt_embeds[:, None, :]
+ # make sure that we attend to this single hidden-state
+ attention_mask = attention_mask.new_ones((len(prompts), 1))
+ else:
+ prompt_embeds = text_encoder(
+ text_input_ids,
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds_list.append(prompt_embeds)
+ attention_mask_list.append(attention_mask)
+
+ # print(f'prompt[0].shape: {prompt_embeds_list[0].shape}')
+ # print(f'prompt[1].shape: {prompt_embeds_list[1].shape}')
+ # print(f'attn[0].shape: {attention_mask_list[0].shape}')
+ # print(f'attn[1].shape: {attention_mask_list[1].shape}')
+
+ projection_output = self.model.projection_model(
+ hidden_states=prompt_embeds_list[0],
+ hidden_states_1=prompt_embeds_list[1],
+ attention_mask=attention_mask_list[0],
+ attention_mask_1=attention_mask_list[1],
+ )
+ projected_prompt_embeds = projection_output.hidden_states
+ projected_attention_mask = projection_output.attention_mask
+
+ generated_prompt_embeds = self.model.generate_language_model(
+ projected_prompt_embeds,
+ attention_mask=projected_attention_mask,
+ max_new_tokens=None,
+ )
+
+ prompt_embeds = prompt_embeds.to(dtype=self.model.text_encoder_2.dtype, device=self.device)
+ attention_mask = (
+ attention_mask.to(device=self.device)
+ if attention_mask is not None
+ else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=self.device)
+ )
+ generated_prompt_embeds = generated_prompt_embeds.to(dtype=self.model.language_model.dtype, device=self.device)
+
+ return generated_prompt_embeds, prompt_embeds, attention_mask
+
+ def get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.model.scheduler.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.get_alpha_prod_t_prev(prev_timestep)
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+ return variance
+
+ def get_alpha_prod_t_prev(self, prev_timestep):
+ return self.model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 \
+ else self.model.scheduler.final_alpha_cumprod
+
+ def unet_forward(self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ timestep_cond: Optional[torch.Tensor] = None,
+ class_labels: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ replace_h_space: Optional[torch.Tensor] = None,
+ replace_skip_conns: Optional[Dict[int, torch.Tensor]] = None,
+ zero_out_resconns: Optional[Union[int, List]] = None) -> Tuple:
+
+ # Translation: the wrapper's class_labels / encoder_attention_mask arguments carry
+ # AudioLDM2's second text-conditioning stream (encoder_hidden_states_1 / encoder_attention_mask_1)
+ encoder_hidden_states_1 = class_labels
+ class_labels = None
+ encoder_attention_mask_1 = encoder_attention_mask
+ encoder_attention_mask = None
+
+ # return self.model.unet(sample, timestep,
+ # encoder_hidden_states=generated_prompt_embeds,
+ # encoder_hidden_states_1=encoder_hidden_states_1,
+ # encoder_attention_mask_1=encoder_attention_mask_1,
+ # ), None, None
+
+ # By default samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2 ** self.model.unet.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ # print("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ if encoder_attention_mask_1 is not None:
+ encoder_attention_mask_1 = (1 - encoder_attention_mask_1.to(sample.dtype)) * -10000.0
+ encoder_attention_mask_1 = encoder_attention_mask_1.unsqueeze(1)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.model.unet.time_proj(timesteps)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.model.unet.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.model.unet.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.model.unet.config.class_embed_type == "timestep":
+ class_labels = self.model.unet.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.model.unet.class_embedding(class_labels).to(dtype=sample.dtype)
+
+ if self.model.unet.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.model.unet.time_embed_act is not None:
+ emb = self.model.unet.time_embed_act(emb)
+
+ # 2. pre-process
+ sample = self.model.unet.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.model.unet.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ encoder_hidden_states_1=encoder_hidden_states_1,
+ encoder_attention_mask_1=encoder_attention_mask_1,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ if self.model.unet.mid_block is not None:
+ sample = self.model.unet.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ encoder_hidden_states_1=encoder_hidden_states_1,
+ encoder_attention_mask_1=encoder_attention_mask_1,
+ )
+
+ if replace_h_space is None:
+ h_space = sample.clone()
+ else:
+ h_space = replace_h_space
+ sample = replace_h_space.clone()
+
+ if mid_block_additional_residual is not None:
+ sample = sample + mid_block_additional_residual
+
+ extracted_res_conns = {}
+ # 5. up
+ for i, upsample_block in enumerate(self.model.unet.up_blocks):
+ is_final_block = i == len(self.model.unet.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+ if replace_skip_conns is not None and replace_skip_conns.get(i):
+ res_samples = replace_skip_conns.get(i)
+
+ if zero_out_resconns is not None:
+ if (type(zero_out_resconns) is int and i >= (zero_out_resconns - 1)) or \
+ (type(zero_out_resconns) is list and i in zero_out_resconns):
+ res_samples = [torch.zeros_like(x) for x in res_samples]
+ # down_block_res_samples = [torch.zeros_like(x) for x in down_block_res_samples]
+
+ extracted_res_conns[i] = res_samples
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ encoder_hidden_states_1=encoder_hidden_states_1,
+ encoder_attention_mask_1=encoder_attention_mask_1,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+
+ # 6. post-process
+ if self.model.unet.conv_norm_out:
+ sample = self.model.unet.conv_norm_out(sample)
+ sample = self.model.unet.conv_act(sample)
+ sample = self.model.unet.conv_out(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample), h_space, extracted_res_conns
+
+ def forward(self, *args, **kwargs):
+ return self
+
+
+def load_model(model_id, device, num_diffusion_steps, double_precision=False):
+ ldm_stable = AudioLDM2Wrapper(model_id=model_id, device=device, double_precision=double_precision)
+ ldm_stable.load_scheduler()
+ ldm_stable.model.scheduler.set_timesteps(num_diffusion_steps, device=device)
+ torch.cuda.empty_cache()
+ # controller = AttentionStore()
+ # controller = EmptyControl()
+ # register_attention_control(ldm_stable.model, controller)
+ # return ldm_stable, controller
+ return ldm_stable
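+
+
+# Example (illustrative sketch, not from the repo): loading the public AudioLDM2 checkpoint
+# with 50 DDIM steps, assuming a CUDA device is available:
+#   pipe = load_model("cvssp/audioldm2", device="cuda", num_diffusion_steps=50)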
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7fb3722dcd2f7a0b55a4f6ac5f5b7b2e7d824254
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,11 @@
+torch
+torchaudio
+diffusers
+accelerate
+transformers
+tqdm
+soundfile
+progressbar
+einops
+scipy
+librosa==0.9.2
diff --git a/style.css b/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..1d6eedacc8cb83b31c9c0072b37a886e32308f63
--- /dev/null
+++ b/style.css
@@ -0,0 +1,4 @@
+.gradio-container {
+ max-width: 1050px !important;
+ padding-top: 1.5rem !important;
+}
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..52a337c2250026cca334f26ae9357d912495e642
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,71 @@
+import numpy as np
+import torch
+from typing import Optional, List, Tuple, NamedTuple, Union
+from models import PipelineWrapper
+
+
+class PromptEmbeddings(NamedTuple):
+ embedding_hidden_states: torch.Tensor
+ embedding_class_lables: torch.Tensor
+ boolean_prompt_mask: torch.Tensor
+
+
+def load_audio(audio_path: Union[str, np.ndarray], fn_STFT, left: int = 0, right: int = 0,
+ device: Optional[torch.device] = None) -> torch.Tensor:
+ if type(audio_path) is str:
+ import audioldm
+ import audioldm.audio
+
+ duration = audioldm.utils.get_duration(audio_path)
+
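+ # 102.4 mel frames per second: the AudioLDM convention of 1024 frames per 10 s of audio,
+ # which is assumed here to match the fn_STFT configuration (16 kHz, hop 160)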
+ mel, _, _ = audioldm.audio.wav_to_fbank(audio_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT)
+ mel = mel.unsqueeze(0)
+ else:
+ mel = audio_path
+
+ c, h, w = mel.shape
+ left = min(left, w-1)
+ right = min(right, w - left - 1)
+ mel = mel[:, :, left:w-right]
+ mel = mel.unsqueeze(0).to(device)
+
+ return mel
+
+
+def get_height_of_spectrogram(length: int, ldm_stable: PipelineWrapper) -> int:
+ vocoder_upsample_factor = np.prod(ldm_stable.model.vocoder.config.upsample_rates) / \
+ ldm_stable.model.vocoder.config.sampling_rate
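+ # e.g. for the default AudioLDM2 vocoder (upsample rates multiplying to 160 at 16 kHz),
+ # vocoder_upsample_factor = 160 / 16000 = 0.01 s of audio per mel frame -- assumed values,
+ # the actual numbers are read from the vocoder config at runtime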
+
+ if length is None:
+ length = ldm_stable.model.unet.config.sample_size * ldm_stable.model.vae_scale_factor * \
+ vocoder_upsample_factor
+
+ height = int(length / vocoder_upsample_factor)
+
+ # original_waveform_length = int(length * ldm_stable.model.vocoder.config.sampling_rate)
+ if height % ldm_stable.model.vae_scale_factor != 0:
+ height = int(np.ceil(height / ldm_stable.model.vae_scale_factor)) * ldm_stable.model.vae_scale_factor
+ print(
+ f"Audio length in seconds {length} is increased to {height * vocoder_upsample_factor} "
+ f"so that it can be handled by the model. It will be cut to {length} after the "
+ f"denoising process."
+ )
+
+ return height
+
+
+def get_text_embeddings(target_prompt: List[str], target_neg_prompt: List[str], ldm_stable: PipelineWrapper
+ ) -> Tuple[torch.Tensor, PromptEmbeddings, PromptEmbeddings]:
+ text_embeddings_hidden_states, text_embeddings_class_labels, text_embeddings_boolean_prompt_mask = \
+ ldm_stable.encode_text(target_prompt)
+ uncond_embedding_hidden_states, uncond_embedding_class_lables, uncond_boolean_prompt_mask = \
+ ldm_stable.encode_text(target_neg_prompt)
+
+ text_emb = PromptEmbeddings(embedding_hidden_states=text_embeddings_hidden_states,
+ boolean_prompt_mask=text_embeddings_boolean_prompt_mask,
+ embedding_class_lables=text_embeddings_class_labels)
+ uncond_emb = PromptEmbeddings(embedding_hidden_states=uncond_embedding_hidden_states,
+ boolean_prompt_mask=uncond_boolean_prompt_mask,
+ embedding_class_lables=uncond_embedding_class_lables)
+
+ return text_embeddings_class_labels, text_emb, uncond_emb