import os, json
import math, random
from multiprocessing import Pool

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from transformers import CLIPTextModel
from transformers import PretrainedConfig


def pad_spec(spec, spec_length, pad_value=0, random_crop=True):
    """Pad or crop a spectrogram along the time axis to exactly `spec_length` frames."""
    assert spec_length % 8 == 0, "spec_length must be divisible by 8"
    if spec.shape[-1] < spec_length:
        # too short: right-pad the time axis
        spec = F.pad(spec, (0, spec_length - spec.shape[-1]), value=pad_value)
    else:
        # long enough: crop, either at a random offset or from the start
        if random_crop:
            start = random.randint(0, spec.shape[-1] - spec_length)
            spec = spec[:, :, start:start + spec_length]
        else:
            spec = spec[:, :, :spec_length]
    return spec


def load_spec(spec_path):
    """Load a spectrogram from a .pt or .npy file and return a [3, mel_dim, spec_len] tensor."""
    if spec_path.endswith(".pt"):
        spec = torch.load(spec_path, map_location="cpu")
    elif spec_path.endswith(".npy"):
        spec = torch.from_numpy(np.load(spec_path))
    else:
        raise ValueError(f"Unknown spec file type {spec_path}")
    assert len(spec.shape) == 3, f"spec shape must be [3, mel_dim, spec_len], got {spec.shape}"
    if spec.size(0) == 1:
        # single-channel spectrograms are repeated to 3 channels
        spec = spec.repeat(3, 1, 1)
    return spec


def random_crop_spec(spec, target_spec_length, pad_value=0, frame_per_sec=100, time_step=5):
    """Crop a spectrogram at a random offset aligned to a `time_step`-second grid.

    Returns the cropped (and, if needed, padded) spec together with the crop start,
    crop end, and full duration, all in seconds.
    """
    assert target_spec_length % 8 == 0, "target_spec_length must be divisible by 8"

    spec_length = spec.shape[-1]
    full_s = math.ceil(spec_length / frame_per_sec / time_step) * time_step
    start_s = random.randint(0, math.floor(spec_length / frame_per_sec / time_step)) * time_step
    end_s = min(start_s + math.ceil(target_spec_length / frame_per_sec), full_s)

    spec = spec[:, :, start_s * frame_per_sec : end_s * frame_per_sec]

    if spec.shape[-1] < target_spec_length:
        spec = F.pad(spec, (0, target_spec_length - spec.shape[-1]), value=pad_value)
    else:
        spec = spec[:, :, :target_spec_length]

    return spec, int(start_s), int(end_s), int(full_s)
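
# Usage sketch (illustrative; the spectrogram path below is hypothetical):
#
#   spec = load_spec("data/specs/example.pt")                       # [3, mel_dim, T]
#   crop, start_s, end_s, full_s = random_crop_spec(spec, target_spec_length=1024)
#   fixed = pad_spec(spec, spec_length=1024, random_crop=False)     # deterministic head crop/pad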


def load_condion_embed(text_embed_path):
    """Load a pre-computed condition/text embedding from a .pt or .npy file."""
    if text_embed_path.endswith(".pt"):
        text_embed_list = torch.load(text_embed_path, map_location="cpu")
    elif text_embed_path.endswith(".npy"):
        text_embed_list = torch.from_numpy(np.load(text_embed_path))
    else:
        raise ValueError(f"Unknown text embedding file type {text_embed_path}")
    if isinstance(text_embed_list, list):
        # a file may store several embeddings; pick one at random
        text_embed = random.choice(text_embed_list)
    else:
        text_embed = text_embed_list
    if len(text_embed.shape) == 3:
        text_embed = text_embed.squeeze(0)
    return text_embed.detach().cpu()


def process_condition_embed(cond_emb, max_length):
    """Zero-pad or truncate a condition embedding to `max_length` tokens."""
    if cond_emb.shape[0] < max_length:
        cond_emb = F.pad(cond_emb, (0, 0, 0, max_length - cond_emb.shape[0]), value=0)
    else:
        cond_emb = cond_emb[:max_length, :]
    return cond_emb
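
# Usage sketch (illustrative; the embedding path and max_length are hypothetical):
#
#   cond = load_condion_embed("data/embeds/example.pt")    # [seq_len, cond_dim]
#   cond = process_condition_embed(cond, max_length=77)    # zero-padded / truncated to 77 tokens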


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str):
    text_encoder_config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path)
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel
        return CLIPTextModel
    elif "t5" in model_class.lower():
        from transformers import T5EncoderModel
        return T5EncoderModel
    elif "clap" in model_class.lower():
        from transformers import ClapTextModelWithProjection
        return ClapTextModelWithProjection
    else:
        raise ValueError(f"{model_class} is not supported.")
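
# Usage sketch (the path is hypothetical and must point to a text-encoder config whose
# `architectures[0]` is a CLIP, T5, or CLAP text model):
#
#   text_encoder_cls = import_model_class_from_model_name_or_path("path/to/text_encoder")
#   text_encoder = text_encoder_cls.from_pretrained("path/to/text_encoder")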


def str2bool(string):
    str2val = {"True": True, "False": False, "true": True, "false": False, "none": False, "None": False}
    if string in str2val:
        return str2val[string]
    else:
        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
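
# str2bool is handy as an argparse `type=` converter; the flag name below is illustrative:
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--some_flag", type=str2bool, default=False)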


def str2str(string):
    """Map "none", "null", "false" (case-insensitive) and the empty string to None; otherwise return the string unchanged."""
    if string.lower() in ("none", "null", "false") or string == "":
        return None
    else:
        return string


def json_dump(data_json, json_save_path):
    with open(json_save_path, 'w') as f:
        json.dump(data_json, f, indent=4)


def json_load(json_path):
    with open(json_path, 'r') as f:
        data = json.load(f)
    return data


def load_json_list(path):
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f]


def save_json_list(data, path):
    with open(path, 'w', encoding='utf-8') as f:
        for d in data:
            f.write(json.dumps(d) + '\n')


def multiprocess_function(func, func_args, n_jobs=32):
    """Run `func` over `func_args` in a process pool with a progress bar; results are discarded."""
    with Pool(processes=n_jobs) as p:
        with tqdm(total=len(func_args)) as pbar:
            for _ in p.imap_unordered(func, func_args):
                pbar.update()
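
# Usage sketch (the worker below is illustrative; it must be a top-level function so
# multiprocessing can pickle it):
#
#   def _extract_one(spec_path):
#       ...                                   # per-item work
#
#   multiprocess_function(_extract_one, spec_paths, n_jobs=8)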


def image_add_color(spec_img):
    """Colorize a grayscale spectrogram image with the viridis colormap."""
    cmap = plt.get_cmap('viridis')
    image = cmap(np.array(spec_img)[:, :, 0])[:, :, :3]
    image = (image - image.min()) / (image.max() - image.min())
    image = Image.fromarray(np.uint8(image * 255))
    return image


def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
    """
    Convert a PyTorch tensor to a NumPy image.
    """
    images = images.cpu().permute(0, 2, 3, 1).float().numpy()
    return images


def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single-channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images
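
# Usage sketch: turn a batch of images in [0, 1] into PIL images.
#
#   imgs_np = pt_to_numpy(batch)              # [B, C, H, W] tensor -> [B, H, W, C] array
#   imgs_pil = numpy_to_pil(imgs_np)          # list of PIL.Image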


def normalize(images):
    """
    Normalize an image array to [-1, 1].
    """
    if images.min() >= 0:
        return 2.0 * images - 1.0
    else:
        return images


def denormalize(images):
    """
    Denormalize an image array to [0, 1].
    """
    if images.min() < 0:
        return (images / 2 + 0.5).clamp(0, 1)
    else:
        return images.clamp(0, 1)


def prepare_mask_and_masked_image(image, mask):
    """
    Prepare a binary mask and the masked image.

    Parameters:
    - image (torch.Tensor): The input image tensor of shape [3, height, width] with values in the range [0, 1].
    - mask (torch.Tensor): The input mask tensor of shape [1, height, width].

    Returns:
    - tuple: A tuple containing the binary mask and the masked image.
    """
    # rescale to [0, 1] if the image is not already in that range
    if image.max() > 1:
        image = (image - image.min()) / (image.max() - image.min())

    # map [0, 1] images to [-1, 1]
    if image.min() >= 0:
        image = normalize(image)

    # zero out the masked regions (mask >= 0.5)
    masked_image = image * (mask < 0.5)

    return mask, masked_image
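
# Usage sketch (shapes are illustrative):
#
#   image = torch.rand(3, 256, 1024)          # spectrogram image in [0, 1]
#   mask = torch.zeros(1, 256, 1024)
#   mask[:, :, 512:] = 1.0                    # mask the second half of the time axis
#   mask, masked_image = prepare_mask_and_masked_image(image, mask)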


def torch_to_pil(image):
    """
    Convert a torch tensor to a PIL image.
    """
    if image.min() < 0:
        image = denormalize(image)

    return transforms.ToPILImage()(image.cpu().detach().squeeze())


class ConditionAdapter(nn.Module):
    """Linear projection followed by LayerNorm, mapping condition embeddings from `condition_dim` to `cross_attention_dim`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.proj = nn.Linear(self.config["condition_dim"], self.config["cross_attention_dim"])
        self.norm = nn.LayerNorm(self.config["cross_attention_dim"])
        print(f"INITIATED: ConditionAdapter: {self.config}")

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return x

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path):
        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        ckpt_path = os.path.join(pretrained_model_name_or_path, "condition_adapter.pt")
        config = json_load(config_path)
        instance = cls(config)
        instance.load_state_dict(torch.load(ckpt_path))
        print(f"LOADED: ConditionAdapter from {pretrained_model_name_or_path}")
        return instance

    def save_pretrained(self, pretrained_model_name_or_path):
        os.makedirs(pretrained_model_name_or_path, exist_ok=True)
        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        ckpt_path = os.path.join(pretrained_model_name_or_path, "condition_adapter.pt")
        json_dump(self.config, config_path)
        torch.save(self.state_dict(), ckpt_path)
        print(f"SAVED: ConditionAdapter {self.config['condition_adapter_name']} to {pretrained_model_name_or_path}")
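
# Usage sketch (dimensions and save directory are illustrative):
#
#   adapter = ConditionAdapter({"condition_adapter_name": "demo",
#                               "condition_dim": 512,
#                               "cross_attention_dim": 768})
#   out = adapter(torch.randn(1, 77, 512))    # -> [1, 77, 768]
#   adapter.save_pretrained("ckpts/condition_adapter")
#   adapter = ConditionAdapter.from_pretrained("ckpts/condition_adapter")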