diff --git "a/Vaani/vaani_scripts/CSIP.py" "b/Vaani/vaani_scripts/CSIP.py"
new file mode 100644
--- /dev/null
+++ "b/Vaani/vaani_scripts/CSIP.py"
@@ -0,0 +1,7044 @@
+# ==================================================================
+# C S I P
+# ==================================================================
+# Author:: ASHISH KUMAR UCHADIYA
+# Date:: 2024-May-27
+#
+#<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
+
+# ==================================================================
+# I M P O R T S
+# ==================================================================
+from __future__ import annotations
+import warnings
+warnings.filterwarnings("ignore")
+
+import os
+import io
+import sys
+import math
+import random
+import collections
+import collections.abc
+import re
+from itertools import repeat
+from pathlib import Path
+from typing import Optional, Tuple, Union, List, Dict
+
+import csv
+import copy
+import numpy as np
+import pandas as pd
+from PIL import Image
+import seaborn as sns
+import matplotlib.pyplot as plt
+from tqdm import trange, tqdm
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn.init import _calculate_fan_in_and_fan_out
+import torch.utils.checkpoint as checkpoint
+
+import torchvision
+from torchvision.transforms import v2
+from torch.utils.tensorboard import SummaryWriter
+# from tensorboardX import SummaryWriter
+
+# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+# print(f"Using device: {device}")
+
+import torchaudio
+import torchaudio.transforms as T
+from torchlibrosa.stft import Spectrogram, LogmelFilterBank
+from torchlibrosa.augmentation import SpecAugmentation
+
+from transformers import AutoModel, AutoTokenizer, logging
+from huggingface_hub.file_download import hf_hub_download
+from peft import get_peft_config, get_peft_model
+from transformers import CLIPVisionModel, AutoProcessor
+
+# from watermark import watermark
+# print(watermark(
+#     author='Ashish',
+#     # email='ashish@example.com',
+#     current_date=True,
+#     datename=True,
+#     current_time=True,
+#     iso8601=True,
+#     timezone=True,
+#     updated=True,
+#     custom_time=None,
+#     python=True,
+#     # packages="torch,torchvision,numpy",
+#     conda=True,
+#     hostname=True,
+#     machine=True,
+#     watermark=False,
+#     iversions=True,
+#     gpu=True,
+#     globals_=globals()
+# ))
+
+
+# ==================================================================
+# H T S - A T
+# ==================================================================
+class HTSATConfig:
+    # Ke Chen
+    # knutchen@ucsd.edu
+    # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
+    # The configuration for training the model
+
+    exp_name = "exp_htsat_pretrain" # the saved ckpt prefix name of the model
+    workspace = "/home/kechen/Research/HTSAT" # the folder of your code
+    dataset_path = "/home/Research/audioset" # the dataset path
+    desed_folder = "/home/Research/DESED" # the desed file
+
+    dataset_type = "audioset" # "audioset" "esc-50" "scv2"
+    index_type = "full_train" # only works for audioset
+    balanced_data = True # only works for audioset
+
+    loss_type = "clip_bce" #
+    # AudioSet & SCV2: "clip_bce" | ESC-50: "clip_ce"
+
+    # trained from a checkpoint, or evaluate a single model
+    resume_checkpoint = None
+    # "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt"
+
+    esc_fold = 0 # just for esc dataset, select the fold
you need for evaluation and (+1) validation + + + debug = False + + random_seed = 970131 # 19970318 970131 12412 127777 1009 34047 + batch_size = 32 * 4 # batch size per GPU x GPU number , default is 32 x 4 = 128 + learning_rate = 1e-3 # 1e-4 also workable + max_epoch = 100 + num_workers = 3 + + lr_scheduler_epoch = [10,20,30] + lr_rate = [0.02, 0.05, 0.1] + + # these data preparation optimizations do not bring many improvements, so deprecated + enable_token_label = False # token label + class_map_path = "class_hier_map.npy" + class_filter = None + retrieval_index = [15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238, 5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762, + 9840, 11318, 8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418, 2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260, 4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336, 8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710, 10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535, 7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274, 4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229, 2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900] + token_label_range = [0.2,0.6] + enable_time_shift = False # shift time + enable_label_enhance = False # enhance hierarchical label + enable_repeat_mode = False # repeat the spectrogram / reshape the spectrogram + + + + # for model's design + enable_tscam = True # enbale the token-semantic layer + + # for signal processing + sample_rate = 32000 # 16000 for scv2, 32000 for audioset and esc-50 + clip_samples = sample_rate * 10 # audio_set 10-sec clip + window_size = 1024 + hop_size = 320 # 160 for scv2, 320 for audioset and esc-50 + mel_bins = 64 + fmin = 50 + fmax = 14000 + shift_max = int(clip_samples * 0.5) + + # for data collection + classes_num = 527 # esc: 50 | audioset: 527 | scv2: 35 + patch_size = (25, 4) # deprecated + crop_size = None # int(clip_samples * 0.5) deprecated + + # for htsat hyperparamater + htsat_window_size = 8 + htsat_spec_size = 256 + htsat_patch_size = 4 + htsat_stride = (4, 4) + htsat_num_head = [4,8,16,32] + htsat_dim = 96 + htsat_depth = [2,2,6,2] + + swin_pretrain_path = None + # "/home/Research/model_backup/pretrain/swin_tiny_c24_patch4_window8_256.pth" + + # Some Deprecated Optimization in the model design, check the model code for details + htsat_attn_heatmap = False + htsat_hier_output = False + htsat_use_max = False + + + # for ensemble test + + ensemble_checkpoints = [] + ensemble_strides = [] + + + # weight average folder + wa_folder = "/home/version_0/checkpoints/" + # weight average output filename + wa_model_path = "HTSAT_AudioSet_Saved_x.ckpt" + + esm_model_pathes = [ + "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt", + "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt", + "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt", + "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt", + "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt", + "/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt" + ] + + # for framewise localization + heatmap_dir = "/home/Research/heatmap_output" + test_file = "htsat-test-ensemble" + fl_local = False # indicate if we need to use this dataset for the framewise detection + fl_dataset = "/home/Research/desed/desedim_embval.npy" + fl_class_num = [ + "Speech", "Frying", "Dishes", "Running_water", + "Blender", "Electric_shaver_toothbrush", 
"Alarm_bell_ringing", + "Cat", "Dog", "Vacuum_cleaner" + ] + + # map 527 classes into 10 classes + fl_audioset_mapping = [ + [0,1,2,3,4,5,6,7], + [366, 367, 368], + [364], + [288, 289, 290, 291, 292, 293, 294, 295, 296, 297], + [369], + [382], + [310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402], + [81, 82, 83, 84, 85], + [74, 75, 76, 77, 78, 79], + [377] + ] + + + +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple + +def do_mixup(x, mixup_lambda): + """Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes + (1, 3, 5, ...). + Args: + x: (batch_size * 2, ...) + mixup_lambda: (batch_size * 2,) + Returns: + out: (batch_size, ...) + """ + out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \ + x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1) + return out + +def interpolate(x, ratio): + """Interpolate data in time domain. This is used to compensate the + resolution reduction in downsampling of a CNN. + + Args: + x: (batch_size, time_steps, classes_num) + ratio: int, ratio to interpolate + Returns: + upsampled: (batch_size, time_steps * ratio, classes_num) + """ + (batch_size, time_steps, classes_num) = x.shape + upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) + upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) + return upsampled + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patch_stride = to_2tuple(patch_stride) + self.img_size = img_size + self.patch_size = patch_size + self.patch_stride = patch_stride + self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + self.in_chans = in_chans + self.embed_dim = embed_dim + + padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2) + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +def _no_gradim_audiorunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. 
The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_gradim_audiorunc_normal_(tensor, mean, std, a, b) + + +def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == 'fan_in': + denom = fan_in + elif mode == 'fan_out': + denom = fan_out + elif mode == 'fan_avg': + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978) + elif distribution == "normal": + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal') + + +# below codes are based and referred from https://github.com/microsoft/Swin-Transformer +# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + def extra_repr(self): + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + +# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.norm_before_mlp = norm_before_mlp + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + if self.norm_before_mlp == 'ln': + self.norm2 = nn.LayerNorm(dim) + elif self.norm_before_mlp == 'bn': + self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2) + else: + raise NotImplementedError + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + # pdb.set_trace() + H, W = self.input_resolution + # print("H: ", H) + # print("W: ", W) + # pdb.set_trace() + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # 
merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x, attn + + def extra_repr(self): + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self): + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + norm_before_mlp='ln'): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, norm_before_mlp=norm_before_mlp) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + attns = [] + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x, attn = blk(x) + if not self.training: + attns.append(attn.unsqueeze(0)) + if self.downsample is not None: + x = self.downsample(x) + if not self.training: + attn = torch.cat(attns, dim = 0) + attn = torch.mean(attn, dim = 0) + return x, attn + + def extra_repr(self): + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + +# The Core of HTSAT +class HTSAT_Swin_Transformer(nn.Module): + r"""HTSAT based on the Swin Transformer + Args: + spec_size (int | tuple(int)): Input Spectrogram size. Default 256 + patch_size (int | tuple(int)): Patch size. Default: 4 + path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4 + in_chans (int): Number of input image channels. Default: 1 (mono) + num_classes (int): Number of classes for classification head. Default: 527 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 8 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False + config (module): The configuration Module from config.py (HTSATConfig Class) + """ + + def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4), + in_chans=1, num_classes=527, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32], + window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, patch_norm=True, + use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs): + super(HTSAT_Swin_Transformer, self).__init__() + + self.config = config + self.spec_size = spec_size + self.patch_stride = patch_stride + self.patch_size = patch_size + self.window_size = window_size + self.embed_dim = embed_dim + self.depths = depths + self.ape = ape + self.in_chans = in_chans + self.num_classes = num_classes + self.num_heads = num_heads + self.num_layers = len(self.depths) + self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1)) + + self.drop_rate = drop_rate + self.attn_drop_rate = attn_drop_rate + self.drop_path_rate = drop_path_rate + + self.qkv_bias = qkv_bias + self.qk_scale = None + + self.patch_norm = patch_norm + self.norm_layer = norm_layer if self.patch_norm else None + self.norm_before_mlp = norm_before_mlp + self.mlp_ratio = mlp_ratio + + self.use_checkpoint = use_checkpoint + + # process mel-spec ; used only once + self.freq_ratio = self.spec_size // self.config.mel_bins + window = 'hann' + center = True + pad_mode = 'reflect' + ref = 1.0 + amin = 1e-10 + top_db = None + self.interpolate_ratio = 32 # Downsampled ratio + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size, + win_length=config.window_size, window=window, center=center, pad_mode=pad_mode, + freeze_parameters=True) + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size, + n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db, + freeze_parameters=True) + # Spec augmenter + self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, + freq_drop_width=8, freq_stripes_num=2) # 2 2 + self.bn0 = nn.BatchNorm2d(self.config.mel_bins) + + + # split spctrogram into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans, + embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride) + + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.grid_size + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=self.drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=self.depths[i_layer], + num_heads=self.num_heads[i_layer], + window_size=self.window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, qk_scale=self.qk_scale, + drop=self.drop_rate, attn_drop=self.attn_drop_rate, + 
drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])], + norm_layer=self.norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + norm_before_mlp=self.norm_before_mlp) + self.layers.append(layer) + + self.norm = self.norm_layer(self.num_features) + self.avgpool = nn.AdaptiveAvgPool1d(1) + self.maxpool = nn.AdaptiveMaxPool1d(1) + + if self.config.enable_tscam: + SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio + self.tscam_conv = nn.Conv2d( + in_channels = self.num_features, + out_channels = self.num_classes, + kernel_size = (SF,3), + padding = (0,1) + ) + self.head = nn.Linear(num_classes, num_classes) + else: + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward_features(self, x): + frames_num = x.shape[2] + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + for i, layer in enumerate(self.layers): + x, attn = layer(x) + + if self.config.enable_tscam: + # for x + x = self.norm(x) + B, N, C = x.shape + SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] + ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1] + x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST) + B, C, F, T = x.shape + # group 2D CNN + c_freq_bin = F // self.freq_ratio + x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T) + x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1) + + # get latent_output + latent_output = self.avgpool(torch.flatten(x,2)) + latent_output = torch.flatten(latent_output, 1) + + # display the attention map, if needed + if self.config.htsat_attn_heatmap: + # for attn + attn = torch.mean(attn, dim = 1) + attn = torch.mean(attn, dim = 1) + attn = attn.reshape(B, SF, ST) + c_freq_bin = SF // self.freq_ratio + attn = attn.reshape(B, SF // c_freq_bin, c_freq_bin, ST) + attn = attn.permute(0,2,1,3).contiguous().reshape(B, c_freq_bin, -1) + attn = attn.mean(dim = 1) + attn_max = torch.max(attn, dim = 1, keepdim = True)[0] + attn_min = torch.min(attn, dim = 1, keepdim = True)[0] + attn = ((attn * 0.15) + (attn_max * 0.85 - attn_min)) / (attn_max - attn_min) + attn = attn.unsqueeze(dim = 2) + + x = self.tscam_conv(x) + x = torch.flatten(x, 2) # B, C, T + + if self.config.htsat_attn_heatmap: + fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous() * attn, 8 * self.patch_stride[1]) + else: + fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1]) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + + if self.config.loss_type == "clip_ce": + output_dict = { + 'framewise_output': fpx, # already sigmoided + 'clipwise_output': x, + 'latent_output': latent_output + } + else: + output_dict = { + 'framewise_output': fpx, # already sigmoided + 'clipwise_output': torch.sigmoid(x), + 'latent_output': latent_output + } + + else: + x = self.norm(x) # B N C + B, N, C = x.shape + + fpx = 
x.permute(0,2,1).contiguous().reshape(B, C, frames_num // (2 ** (len(self.depths) + 1)), frames_num // (2 ** (len(self.depths) + 1)) ) + B, C, F, T = fpx.shape + c_freq_bin = F // self.freq_ratio + fpx = fpx.reshape(B, C, F // c_freq_bin, c_freq_bin, T) + fpx = fpx.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1) + fpx = torch.sum(fpx, dim = 2) + fpx = interpolate(fpx.permute(0,2,1).contiguous(), 8 * self.patch_stride[1]) + x = self.avgpool(x.transpose(1, 2)) # B C 1 + x = torch.flatten(x, 1) + if self.num_classes > 0: + x = self.head(x) + fpx = self.head(fpx) + output_dict = {'framewise_output': torch.sigmoid(fpx), + 'clipwise_output': torch.sigmoid(x)} + return output_dict + + def crop_wav(self, x, crop_size, spe_pos = None): + time_steps = x.shape[2] + tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device) + for i in range(len(x)): + if spe_pos is None: + crop_pos = random.randint(0, time_steps - crop_size - 1) + else: + crop_pos = spe_pos + tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:] + return tx + + # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model + def reshape_wav2img(self, x): + B, C, T, F = x.shape + target_T = int(self.spec_size * self.freq_ratio) + target_F = self.spec_size // self.freq_ratio + assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size" + # to avoid bicubic zero error + if T < target_T: + x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True) + if F < target_F: + x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True) + x = x.permute(0,1,3,2).contiguous() + x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio) + # print(x.shape) + x = x.permute(0,1,3,2,4).contiguous() + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4]) + return x + + # Repeat the wavform to a img size, if you want to use the pretrained swin transformer model + def repeat_wat2img(self, x, cur_pos): + B, C, T, F = x.shape + target_T = int(self.spec_size * self.freq_ratio) + target_F = self.spec_size // self.freq_ratio + assert T <= target_T and F <= target_F, "the wav size should less than or equal to the swin input size" + # to avoid bicubic zero error + if T < target_T: + x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode="bicubic", align_corners=True) + if F < target_F: + x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True) + x = x.permute(0,1,3,2).contiguous() # B C F T + x = x[:,:,:,cur_pos:cur_pos + self.spec_size] + x = x.repeat(repeats = (1,1,4,1)) + return x + + def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False):# out_feat_keys: List[str] = None): + x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + if self.training: + x = self.spec_augmenter(x) + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + if infer_mode: + # in infer mode. 
we need to handle different length audio input + frame_num = x.shape[2] + target_T = int(self.spec_size * self.freq_ratio) + repeat_ratio = math.floor(target_T / frame_num) + x = x.repeat(repeats=(1,1,repeat_ratio,1)) + x = self.reshape_wav2img(x) + output_dict = self.forward_features(x) + elif self.config.enable_repeat_mode: + if self.training: + cur_pos = random.randint(0, (self.freq_ratio - 1) * self.spec_size - 1) + x = self.repeat_wat2img(x, cur_pos) + output_dict = self.forward_features(x) + else: + output_dicts = [] + for cur_pos in range(0, (self.freq_ratio - 1) * self.spec_size + 1, self.spec_size): + tx = x.clone() + tx = self.repeat_wat2img(tx, cur_pos) + output_dicts.append(self.forward_features(tx)) + clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device) + framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device) + for d in output_dicts: + clipwise_output += d["clipwise_output"] + framewise_output += d["framewise_output"] + clipwise_output = clipwise_output / len(output_dicts) + framewise_output = framewise_output / len(output_dicts) + + output_dict = { + 'framewise_output': framewise_output, + 'clipwise_output': clipwise_output + } + else: + if x.shape[2] > self.freq_ratio * self.spec_size: + if self.training: + x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size) + x = self.reshape_wav2img(x) + output_dict = self.forward_features(x) + else: + # Change: Hard code here + overlap_size = 344 #(x.shape[2] - 1) // 4 + output_dicts = [] + crop_size = 689 #(x.shape[2] - 1) // 2 + for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size): + tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos) + tx = self.reshape_wav2img(tx) + output_dicts.append(self.forward_features(tx)) + clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device) + framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device) + latent_output = torch.zeros_like(output_dicts[0]["latent_output"]).float().to(x.device) + for d in output_dicts: + clipwise_output += d["clipwise_output"] + framewise_output += d["framewise_output"] + latent_output += d["latent_output"] + clipwise_output = clipwise_output / len(output_dicts) + framewise_output = framewise_output / len(output_dicts) + latent_output = latent_output / len(output_dicts) + output_dict = { + 'framewise_output': framewise_output, + 'clipwise_output': clipwise_output, + 'latent_output': latent_output, + } + else: # this part is typically used, and most easy one + x = self.reshape_wav2img(x) + output_dict = self.forward_features(x) + # x = self.head(x) + return output_dict + +class HTSATWrapper(nn.Module): + def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, + fmax, classes_num, out_emb): + super().__init__() + + # print("parameters are being overidden when using HTSAT") + # print("HTSAT only support loading a pretrained model on AudioSet") + # @TODO later look at what parameters are same and can be merged + + self.htsat = HTSAT_Swin_Transformer(config=HTSATConfig()) + + def forward(self, x): + out_dict = self.htsat(x) + out_dict['embedding'] = out_dict['latent_output'] + return out_dict + + +def get_audio_encoder(name: str): + if name == "HTSAT": + return HTSATWrapper + else: + raise Exception('The audio encoder name {} is incorrect or not supported'.format(name)) + +class Projection(nn.Module): + def __init__(self, dim_imgn: int, d_out: int, p: float=0.5) -> None: + 
super().__init__() + self.linear1 = nn.Linear(dim_imgn, d_out, bias=False) + self.linear2 = nn.Linear(d_out, d_out, bias=False) + self.layer_norm = nn.LayerNorm(d_out) + self.drop = nn.Dropout(p) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + embed1 = self.linear1(x) + embed2 = self.drop(self.linear2(F.gelu(embed1))) + embeds = self.layer_norm(embed1 + embed2) + return embeds + +class AudioEncoder(nn.Module): + def __init__(self, audioenc_name:str, dim_imgn: int, d_out: int, sample_rate: int, window_size: int, + hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None: + super().__init__() + + audio_encoder = get_audio_encoder(audioenc_name) + + self.base = audio_encoder( + sample_rate, window_size, + hop_size, mel_bins, fmin, fmax, + classes_num, dim_imgn) + + self.projection = Projection(dim_imgn, d_out) + + def forward(self, x): + out_dict = self.base(x) + audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output'] + projected_vec = self.projection(audio_features) + return projected_vec, audio_classification_output + +class TextEncoder(nn.Module): + def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None: + super().__init__() + self.text_model = text_model + self.base = AutoModel.from_pretrained(text_model) + + if 'clip' in text_model: + self.clip_text_projection = self.base.text_projection + self.base = self.base.text_model + if 'base' in text_model: + transformer_embed_dim = 512 + + self.projection = Projection(transformer_embed_dim, d_out) + + def forward(self, x): + if 'clip' in self.text_model: + pooled_output = self.base(**x)[1] # get pooled output + out = self.clip_text_projection(pooled_output) # get CLS token output + elif 'gpt' in self.text_model: + batch_size = x['input_ids'].shape[0] + hidden_states = self.base(**x)[0] # (batch_size=4, seq_len, 768) + + sequence_lengths = torch.ne(x['input_ids'], 0).sum(-1) - 1 # tensor([13, 14, 18, 17]) + out = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths] # [batch_size, 768] = [4, 768] + else: + out = self.base(**x)[0] + out = out[:, 0, :] # get CLS token output + + projected_vec = self.projection(out) + + return projected_vec + +class CLAP(nn.Module): + def __init__(self, + # audio + audioenc_name: str, + sample_rate: int, + window_size: int, + hop_size: int, + mel_bins: int, + fmin: int, + fmax: int, + classes_num: int, + out_emb: int, + # text + text_model: str, + transformer_embed_dim: int, + # common + d_proj: int, + ): + super().__init__() + + + self.audio_encoder = AudioEncoder( + audioenc_name, out_emb, d_proj, + sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num) + + self.caption_encoder = TextEncoder( + d_proj, text_model, transformer_embed_dim + ) + + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def forward(self, audio, text): + audio_embed, _ = self.audio_encoder(audio) + caption_embed = self.caption_encoder(text) + + return caption_embed, audio_embed, self.logit_scale.exp() + + + +# ================================================================== +# A U D I O - P R E - P R O C E S S I N G +# ================================================================== +def read_audio(audio_path, resample=True): + r"""Loads audio file or array and returns a torch tensor""" + # Randomly sample a segment of audio_duration from the clip or pad to match duration + audio_time_series, sample_rate = torchaudio.load(audio_path) + + resample_rate = 
clapConfig.sample_rate + if resample and resample_rate != sample_rate: + resampler = T.Resample(sample_rate, resample_rate) + audio_time_series = resampler(audio_time_series) + return audio_time_series, resample_rate + +def load_audio_into_tensor(audio_path, audio_duration, resample=False): + r"""Loads audio file and returns raw audio.""" + # Randomly sample a segment of audio_duration from the clip or pad to match duration + audio_time_series, sample_rate = read_audio(audio_path, resample) + audio_time_series = audio_time_series.reshape(-1) + + # audio_time_series is shorter than predefined audio duration, + # so audio_time_series is extended + if audio_duration*sample_rate >= audio_time_series.shape[0]: + repeat_factor = int(np.ceil((audio_duration*sample_rate) / + audio_time_series.shape[0])) + # Repeat audio_time_series by repeat_factor to match audio_duration + audio_time_series = audio_time_series.repeat(repeat_factor) + # remove excess part of audio_time_series + audio_time_series = audio_time_series[0:audio_duration*sample_rate] + else: + # audio_time_series is longer than predefined audio duration, + # so audio_time_series is trimmed + start_index = random.randrange( + audio_time_series.shape[0] - audio_duration*sample_rate) + audio_time_series = audio_time_series[start_index:start_index + + audio_duration*sample_rate] + return torch.FloatTensor(audio_time_series) + +np_str_obj_array_pattern = re.compile(r'[SaUO]') +default_collate_err_msg_format = ( + "default_collate: batch must contain tensors, numpy arrays, numbers, " + "dicts or lists; found {}") + +def default_collate(batch): + r"""Puts each data field into a tensor with outer dimension batch size""" + elem = batch[0] + elem_type = type(elem) + if isinstance(elem, torch.Tensor): + out = None + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum([x.numel() for x in batch]) + storage = elem.storage()._new_shared(numel) + out = elem.new(storage) + return torch.stack(batch, 0, out=out) + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': + # array of string classes and object + if np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError( + default_collate_err_msg_format.format(elem.dtype)) + + return default_collate([torch.as_tensor(b) for b in batch]) + elif elem.shape == (): # scalars + return torch.as_tensor(batch) + elif isinstance(elem, float): + return torch.tensor(batch, dtype=torch.float64) + elif isinstance(elem, int): + return torch.tensor(batch) + elif isinstance(elem, str): + return batch + elif isinstance(elem, collections.abc.Mapping): + return {key: default_collate([d[key] for d in batch]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple + return elem_type(*(default_collate(samples) for samples in zip(*batch))) + elif isinstance(elem, collections.abc.Sequence): + # check to make sure that the elements in batch have consistent size + it = iter(batch) + elem_size = len(next(it)) + if not all(len(elem) == elem_size for elem in it): + raise RuntimeError( + 'each element in list of batch should be of equal size') + transposed = zip(*batch) + return [default_collate(samples) for samples in transposed] + + raise TypeError(default_collate_err_msg_format.format(elem_type)) + +def 
preprocess_audio(audio_files, resample): + r"""Load list of audio files and return raw audio""" + audio_tensors = [] + for audio_file in audio_files: + audio_tensor = load_audio_into_tensor( + audio_file, clapConfig.duration, resample) + audio_tensor = audio_tensor.reshape(1, -1) + audio_tensors.append(audio_tensor) + return default_collate(audio_tensors) + + + +# ================================================================== +# A U D I O - E M B E D D I N G S - H E L P E R +# ================================================================== +def CLAPAudioProcessor(audio_files: List[str], resample=True): + preprocessed_audio = preprocess_audio(audio_files, resample) + preprocessed_audio = preprocessed_audio.reshape( + preprocessed_audio.shape[0], preprocessed_audio.shape[2]) + preprocessed_audio = preprocessed_audio + return preprocessed_audio + +def get_audio_embeddings(audio_files: List[str], audio_encoder, resample=True): + """Load list of audio files and return audio embeddings""" + # preprocessed_audio = preprocess_audio(audio_files, resample) + # with torch.no_grad(): + # preprocessed_audio = preprocessed_audio.reshape( + # preprocessed_audio.shape[0], preprocessed_audio.shape[2]) + with torch.no_grad(): + preprocessed_audio = CLAPAudioProcessor(audio_files, resample) + return audio_encoder(preprocessed_audio)[0] + + +# ================================================================== +# C L A P +# ================================================================== +class ClapConfig: + # TEXT ENCODER CONFIG + text_model = 'gpt2' + text_len = 77 + transformer_embed_dim = 768 + freeze_text_encoder_weights = True + + # AUDIO ENCODER CONFIG + audioenc_name = 'HTSAT' + out_emb = 768 + sample_rate = 44100 + duration = 7 + fmin = 50 + fmax = 8000 # 14000 + n_fft = 1024 # 1028 + hop_size = 320 + mel_bins = 64 + window_size = 1024 + + # PROJECTION SPACE CONFIG + d_proj = 1024 + temperature = 0.003 + + # TRAINING AND EVALUATION CONFIG + num_classes = 527 + batch_size = 1024 + demo = False + + +clapConfig = ClapConfig() +clap = CLAP( + audioenc_name=clapConfig.audioenc_name, + sample_rate=clapConfig.sample_rate, + window_size=clapConfig.window_size, + hop_size=clapConfig.hop_size, + mel_bins=clapConfig.mel_bins, + fmin=clapConfig.fmin, + fmax=clapConfig.fmax, + classes_num=clapConfig.num_classes, + out_emb=clapConfig.out_emb, + text_model=clapConfig.text_model, + transformer_embed_dim=clapConfig.transformer_embed_dim, + d_proj=clapConfig.d_proj + ) + +model_repo = "microsoft/msclap" +model_name = { + '2022': 'CLAP_weights_2022.pth', + '2023': 'CLAP_weights_2023.pth', + 'clapcap': 'clapcap_weights_2023.pth' +} + +version = '2023' +model_fp = hf_hub_download(model_repo, model_name[version]) + +model_state_dict = torch.load(model_fp, map_location=torch.device('cpu'))['model'] +clap.load_state_dict(model_state_dict, strict=False) +# clap.eval() + +clap_audio_encoder = clap.audio_encoder.to(device) + +# ENGLISH_AUDIO_DIR = r"/home/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/English" +# audio_files = [os.path.join(ENGLISH_AUDIO_DIR, i) for i in os.listdir(ENGLISH_AUDIO_DIR) if i.endswith(".wav")] +# audio_embedding = get_audio_embeddings(audio_files, clap_audio_encoder) +# print("CLAP Audio Encoder Embeddings:", audio_embedding.shape) # [5, 1024] + + +# ================================================================== +# C L A P - L o R A - M O D E L +# ================================================================== +LoRAconfig = { + "peft_type": "LORA", + "task_type": 
"FEATURE_EXTRACTION", + "inference_mode": False, + "r": 16, + "target_modules": ["qkv", "fc1", "fc2", "proj", "linear1", "linear2"], + "lora_alpha": 32, + "lora_dropout": 0.05, + "fan_in_fan_out": False, + "bias": "all", +} +peft_config = get_peft_config(LoRAconfig) + +peft_model = get_peft_model(clap_audio_encoder, peft_config) + +# peft_model.print_trainable_parameters() + +peft_clap_audio_encoder = peft_model.base_model +# audio_embedding = get_audio_embeddings(audio_files, peft_clap_audio_encoder) +# print("CLAP LoRA Audio Encoder Embeddings:", audio_embedding.shape) # [5, 1024] + + + +# ================================================================== +# O P E N - C L I P - M O D E L +# ================================================================== +# ================================================================== +# I M P O R T S +# ================================================================== + + +import os +import io +import sys +import math +import random +import collections +import collections.abc +import re +from itertools import repeat +from pathlib import Path +from typing import Optional, Tuple, Union, List, Dict + +import csv +import numpy as np +import pandas as pd +from PIL import Image +import seaborn as sns +import matplotlib.pyplot as plt +from tqdm import trange, tqdm + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.init import _calculate_fan_in_and_fan_out +import torch.utils.checkpoint as checkpoint + +import torchvision +from torchvision.transforms import v2 + +# os.environ["CUDA_VISIBLE_DEVICES"] = "1" +# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# print(f"Using device: {device}") + +import torchaudio +import torchaudio.transforms as T +from torchlibrosa.stft import Spectrogram, LogmelFilterBank +from torchlibrosa.augmentation import SpecAugmentation + +from transformers import AutoModel, AutoTokenizer, logging +from huggingface_hub.file_download import hf_hub_download +from huggingface_hub.file_download import hf_hub_download +from peft import get_peft_config, get_peft_model + +from typing import Any, Dict, Optional, Tuple, Union +import numbers +import random +import warnings +from dataclasses import dataclass, asdict +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import torch +import torchvision.transforms.functional as F +from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ + CenterCrop, ColorJitter, Grayscale + +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) +INCEPTION_MEAN = (0.5, 0.5, 0.5) +INCEPTION_STD = (0.5, 0.5, 0.5) + +# Default name for a weights file hosted on the Huggingface Hub. +HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl +HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version +HF_CONFIG_NAME = 'open_clip_config.json' + + +import collections.abc +from itertools import repeat +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn as nn +from torch import _assert +from torchvision.ops.misc import FrozenBatchNorm2d + + +def freeze_batch_norm_2d(module, module_match={}, name=''): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. 
If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + module_match (dict): Dictionary of full module names to freeze (all if empty) + name (str): Full module name (prefix) + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + is_match = True + if module_match: + is_match = name in module_match + if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for child_name, child in module.named_children(): + full_child_name = '.'.join([name, child_name]) if name else child_name + new_child = freeze_batch_norm_2d(child, module_match, full_child_name) + if new_child is not child: + res.add_module(child_name, new_child) + return res + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = lambda n, x: _ntuple(n)(x) + +# Replaces all linear layers with linear_replacement +# TODO: add int8 support for other linear layers including attn and convnets +def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): + for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_linear(module, linear_replacement, include_modules, copy_weights) + + if isinstance(module, torch.nn.Linear) and name in include_modules: + old_module = model._modules[name] + model._modules[name] = linear_replacement( + module.in_features, + module.out_features, + module.bias is not None, + ) + if copy_weights: + model._modules[name].weight.data.copy_(old_module.weight.data) + if model._modules[name].bias is not None: + model._modules[name].bias.data.copy_(old_module.bias) + + return model + +def convert_int8_model_to_inference_mode(model): + for m in model.modules(): + if hasattr(m, 'prepare_for_eval'): + int8_original_dtype = m.weight.dtype + m.prepare_for_eval() + m.int8_original_dtype = int8_original_dtype + + +def feature_take_indices( + num_features: int, + indices: Optional[Union[int, List[int]]] = None, + as_set: bool = False, +) -> Tuple[List[int], int]: + """ Determine the absolute feature indices to 'take' from. + + Note: This function can be called in forward() so must be torchscript compatible, + which requires some incomplete typing and workaround hacks. 
+ + Args: + num_features: total number of features to select from + indices: indices to select, + None -> select all + int -> select last n + list/tuple of int -> return specified (-ve indices specify from end) + as_set: return as a set + + Returns: + List (or set) of absolute (from beginning) indices, Maximum index + """ + if indices is None: + indices = num_features # all features if None + + if isinstance(indices, int): + # convert int -> last n indices + _assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})') + take_indices = [num_features - indices + i for i in range(indices)] + else: + take_indices: List[int] = [] + for i in indices: + idx = num_features + i if i < 0 else i + _assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})') + take_indices.append(idx) + + if not torch.jit.is_scripting() and as_set: + return set(take_indices), max(take_indices) + + return take_indices, max(take_indices) + + +def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]: + if isinstance(x, int): + # if indices is an int, take last N features + return tuple(range(-x, 0)) + return tuple(x) + + + +import copy +import copy +import hashlib +import os +import urllib +import warnings +from functools import partial +from typing import Dict, Iterable, Optional, Union + +from tqdm import tqdm + + +try: + import safetensors.torch + _has_safetensors = True +except ImportError: + _has_safetensors = False + +__version__ = '2.32.0' + + +""" CLIP Model + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. +""" +import copy +import logging +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint +from functools import partial + +# from .hf_model import HFTextEncoder +# from .modified_resnet import ModifiedResNet +from collections import OrderedDict +import math +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint + +# from .utils import to_2tuple, feature_take_indices +# from .pos_embed import get_2d_sincos_pos_embed +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
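+# NOTE (illustrative, commented-out usage): `feature_take_indices` defined above maps a
+# relative `indices` argument onto absolute block indices. A quick sketch, assuming a
+# hypothetical 12-block trunk:
+#   feature_take_indices(12, 2)        # -> ([10, 11], 11): the last two blocks
+#   feature_take_indices(12, [-1, 3])  # -> ([11, 3], 11): negative indices count from the end
+#   feature_take_indices(12, None)     # -> ([0, 1, ..., 11], 11): all blocks
+# The returned max index lets callers stop iterating early (see `stop_early` further below).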
+# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + +import numpy as np + +import torch + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2. + omega = 1. 
/ 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + + +from collections import OrderedDict +from typing import Dict, List, Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F + +# from .utils import freeze_batch_norm_2d, feature_take_indices + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.act1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.act2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.act3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.act1(self.bn1(self.conv1(x))) + out = self.act2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.act3(out) + return out + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x, key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0., + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + + return x[0] + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs antialiasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__( + self, + layers: List[int], + output_dim: int, + heads: int, + image_size: int = 224, + width: int = 64, + ): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.act1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.act2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.act3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + assert unlocked_groups == 0, 'partial locking not currently supported for this model' + for param in self.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + # FIXME support for non-transformer + pass + + def stem(self, x): + x = self.act1(self.bn1(self.conv1(x))) + x = self.act2(self.bn2(self.conv2(x))) + x = self.act3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + stop_early: bool = False, + normalize_intermediates: bool = False, + intermediates_only: bool = False, + output_fmt: str = 'NCHW', + output_extra_tokens: bool = False, + ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. 
+ + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + stop_early: Stop iterating over blocks when last desired intermediate hit + normalize_intermediates: Apply final norm layer to all intermediates + intermediates_only: Only return intermediate features + output_fmt: Shape of intermediate feature outputs + output_extra_tokens: Return both extra class, eot tokens + Returns: + + """ + assert output_fmt in ('NCHW',), 'Output format must be == NCHW.' + # NOTE normalize_intermediates and return_extra_tokens don't apply + take_indices, max_index = feature_take_indices(5, indices) + + output = {} + intermediates = [] + blocks = [self.stem, self.layer1, self.layer2, self.layer3, self.layer4] + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = blocks[:max_index + 1] + for i, blk in enumerate(blocks): + x = blk(x) + if i in take_indices: + intermediates.append(x) + + output['image_intermediates'] = intermediates + + if intermediates_only: + return output + + x = self.attnpool(x) + output['image_features'] = x + + return output + + def forward(self, x): + x = self.stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +""" huggingface model adapter + +Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. +""" +import re + +import torch +import torch.nn as nn +from torch import TensorType + +try: + import transformers + from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig + from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ + BaseModelOutputWithPoolingAndCrossAttentions +except ImportError as e: + transformers = None + + + class BaseModelOutput: + pass + + + class PretrainedConfig: + pass + +# from .hf_configs import arch_dict +# HF architecture dict: +arch_dict = { + # https://huggingface.co/docs/transformers/model_doc/roberta#roberta + "roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig + "xlm-roberta": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": "num_attention_heads", + "layers": "num_hidden_layers", + "layer_attr": "layer", + "token_embeddings_attr": "embeddings" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 + "mt5": { + "config_names": { + # unlimited seqlen + # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 + # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 + "context_length": "", + "vocab_size": "vocab_size", + "width": "d_model", + "heads": "num_heads", + "layers": "num_layers", + "layer_attr": "block", + "token_embeddings_attr": "embed_tokens" + }, + "pooler": "mean_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/bert + "bert": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "hidden_size", + "heads": 
"num_attention_heads", + "layers": "num_hidden_layers", + }, + "pooler": "cls_pooler", + }, + # https://huggingface.co/docs/transformers/model_doc/m2m_100 + "m2m_100": { + "config_names": { + "context_length": "max_position_embeddings", + "vocab_size": "vocab_size", + "width": "d_model", + "heads": "encoder_attention_heads", + "layers": "encoder_layers", + }, + "pooler": "cls_pooler", + }, +} + + + +# utils +def _camel2snake(s): + return re.sub(r'(? Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + stop_early: Stop iterating over blocks when last desired intermediate hit + normalize_intermediates: Apply norm layer to all intermediates + intermediates_only: Only return intermediate features + output_fmt: Shape of intermediate feature outputs + output_extra_tokens: Return both prefix and spatial intermediate tokens + Returns: + """ + extra_args = {} + if output_extra_tokens: + extra_args['return_prefix_tokens'] = True + trunk_output = self.trunk.forward_intermediates( + x, + indices=indices, + intermediates_only=intermediates_only, + norm=normalize_intermediates, + stop_early=stop_early, + output_fmt=output_fmt, + **extra_args, + ) + + return_dict = {} + intermediates = trunk_output if intermediates_only else trunk_output[1] + if output_extra_tokens and intermediates and isinstance(intermediates[0], tuple): + intermediates_prefix = [xi[1] for xi in intermediates] + intermediates = [xi[0] for xi in intermediates] + return_dict['image_intermediates_prefix'] = intermediates_prefix + + return_dict['image_intermediates'] = intermediates + if intermediates_only: + return return_dict + + image_features = self.trunk.forward_head(trunk_output[0]) # run through timm pooling / projection + image_features = self.head(image_features) # run through adapter pooling / projection + return_dict['image_features'] = image_features + return return_dict + + def forward(self, x): + x = self.trunk(x) + x = self.head(x) + return x + + +class LayerNormFp32(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm (with cast back to input dtype).""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x.to(orig_type) + + +class QuickGELU(nn.Module): + # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchDropout(nn.Module): + """ + https://arxiv.org/abs/2212.00794 + """ + + def __init__( + self, + prob: float = 0.5, + exclude_first_token: bool = True + ): + super().__init__() + assert 0 <= prob < 1. 
+ self.prob = prob + self.exclude_first_token = exclude_first_token # exclude CLS token + + def forward(self, x): + if not self.training or self.prob == 0.: + return x + + if self.exclude_first_token: + cls_tokens, x = x[:, :1], x[:, 1:] + else: + cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) + + batch = x.size()[0] + num_tokens = x.size()[1] + + batch_indices = torch.arange(batch) + batch_indices = batch_indices[..., None] + + keep_prob = 1 - self.prob + num_patches_keep = max(1, int(num_tokens * keep_prob)) + + rand = torch.randn(batch, num_tokens) + patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices + + x = x[batch_indices, patch_indices_keep] + + if self.exclude_first_token: + x = torch.cat((cls_tokens, x), dim=1) + + return x + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + scaled_cosine: bool = False, + scale_heads: bool = False, + logit_scale_max: float = math.log(1. / 0.01), + batch_first: bool = True, + attn_drop: float = 0., + proj_drop: float = 0. + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.logit_scale_max = logit_scale_max + self.batch_first = batch_first + self.use_fsdpa = hasattr(nn.functional, 'scaled_dot_product_attention') + + # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original + self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) + if qkv_bias: + self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) + else: + self.in_proj_bias = None + + if self.scaled_cosine: + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + else: + self.logit_scale = None + self.attn_drop = nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = nn.Linear(dim, dim) + self.out_drop = nn.Dropout(proj_drop) + + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + if self.batch_first: + x = x.transpose(0, 1) + + L, N, C = x.shape + q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) + q = q.reshape(L, N * self.num_heads, -1).transpose(0, 1) + k = k.reshape(L, N * self.num_heads, -1).transpose(0, 1) + v = v.reshape(L, N * self.num_heads, -1).transpose(0, 1) + + if attn_mask is not None and attn_mask.dtype == torch.bool: + new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) + new_attn_mask.masked_fill_(attn_mask, float("-inf")) + attn_mask = new_attn_mask + + if self.logit_scale is not None: + attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) + logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() + attn = attn.view(N, self.num_heads, L, L) * logit_scale + attn = attn.view(-1, L, L) + if attn_mask is not None: + attn = attn + attn_mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = torch.bmm(attn, v) + else: + if self.use_fsdpa: + x = F.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = torch.bmm(q, k.transpose(-1, -2)) + if attn_mask is not None: + attn += attn_mask + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = torch.bmm(attn, v) + + if 
self.head_scale is not None: + x = x.view(N, self.num_heads, L, C) * self.head_scale + x = x.view(-1, L, C) + + x = x.transpose(0, 1).reshape(L, N, C) + + if self.batch_first: + x = x.transpose(0, 1) + + x = self.out_proj(x) + x = self.out_drop(x) + return x + + +class AttentionalPooler(nn.Module): + def __init__( + self, + d_model: int, + context_dim: int, + n_head: int = 8, + n_queries: int = 256, + norm_layer: Callable = LayerNorm, + ): + super().__init__() + self.query = nn.Parameter(torch.randn(n_queries, d_model)) + self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True) + self.ln_q = norm_layer(d_model) + self.ln_k = norm_layer(context_dim) + + def forward(self, x: torch.Tensor): + N = x.shape[0] + x = self.ln_k(x) + q = self.ln_q(self.query) + out = self.attn(q.unsqueeze(0).expand(N, -1, -1), x, x, need_weights=False)[0] + return out + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + is_cross_attention: bool = False, + batch_first: bool = True, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=batch_first) + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn( + q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask + )[0] + + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None + v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None + x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class CustomResidualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + scale_cosine_attn: bool = False, + scale_heads: bool = False, + scale_attn: bool = False, + scale_fc: bool = False, + batch_first: bool = True, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.attn = Attention( + d_model, + n_head, + scaled_cosine=scale_cosine_attn, + scale_heads=scale_heads, + batch_first=batch_first, + ) + self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() + self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + 
self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() + + def get_reference_weight(self): + return self.mlp.c_fc.weight + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) + x = x + self.ls_2(self.mlp(self.ln_2(x))) + return x + + +class CustomTransformer(nn.Module): + """ A custom transformer that can use different block types. """ + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + batch_first: bool = True, + block_types: Union[str, List[str]] = 'CustomResidualAttentionBlock', + ): + super().__init__() + self.width = width + self.layers = layers + self.batch_first = batch_first # run transformer stack in batch first (N, L, D) + self.grad_checkpointing = False + + if isinstance(block_types, str): + block_types = [block_types] * layers + assert len(block_types) == layers + + def _create_block(bt: str): + if bt == 'CustomResidualAttentionBlock': + return CustomResidualAttentionBlock( + width, + heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + batch_first=batch_first, + ) + else: + assert False + + self.resblocks = nn.ModuleList([ + _create_block(bt) + for bt in block_types + ]) + + def get_cast_dtype(self) -> torch.dtype: + weight = self.resblocks[0].get_reference_weight() + if hasattr(weight, 'int8_original_dtype'): + return weight.int8_original_dtype + return weight.dtype + + def forward_intermediates( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + indices: Optional[Union[int, List[int]]] = None, + stop_early: bool = False, + ): + take_indices, max_index = feature_take_indices(len(self.resblocks), indices) + + if not self.batch_first: + x = x.transpose(0, 1).contiguous() # NLD -> LND + + intermediates = [] + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.resblocks + else: + blocks = self.resblocks[:max_index + 1] + for i, blk in enumerate(blocks): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, None, None, attn_mask, use_reentrant=False) + else: + x = blk(x, attn_mask=attn_mask) + + if i in take_indices: + intermediates.append(x.transpose(0, 1) if not self.batch_first else x) + + if not self.batch_first: + x = x.transpose(0, 1) # LND -> NLD + + return x, intermediates + + def prune_intermediate_layers(self, indices: Union[int, List[int]] = 1): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.resblocks), indices) + self.resblocks = self.resblocks[:max_index + 1] # truncate blocks + return take_indices + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + if not self.batch_first: + x = x.transpose(0, 1) # NLD -> LND + + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False) + else: + x = r(x, attn_mask=attn_mask) + + if not self.batch_first: + x = x.transpose(0, 1) # NLD -> LND + return x + + +class Transformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + batch_first: bool = True, + ): + super().__init__() + self.width = width + self.layers = layers + self.batch_first = batch_first + self.grad_checkpointing = False + + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + batch_first=batch_first, + ) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'): + return self.resblocks[0].mlp.c_fc.int8_original_dtype + return self.resblocks[0].mlp.c_fc.weight.dtype + + def forward_intermediates( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + indices: Optional[Union[int, List[int]]] = None, + stop_early: bool = False, + ): + take_indices, max_index = feature_take_indices(len(self.resblocks), indices) + + if not self.batch_first: + x = x.transpose(0, 1).contiguous() # NLD -> LND + + intermediates = [] + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.resblocks + else: + blocks = self.resblocks[:max_index + 1] + for i, blk in enumerate(blocks): + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, None, None, attn_mask, use_reentrant=False) + else: + x = blk(x, attn_mask=attn_mask) + + if i in take_indices: + intermediates.append(x.transpose(0, 1) if not self.batch_first else x) + + if not self.batch_first: + x = x.transpose(0, 1) # LND -> NLD + + return x, intermediates + + def prune_intermediate_layers(self, indices: Union[int, List[int]] = 1): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices, max_index = feature_take_indices(len(self.resblocks), indices) + self.resblocks = self.resblocks[:max_index + 1] # truncate blocks + return take_indices + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + if not self.batch_first: + x = x.transpose(0, 1).contiguous() # NLD -> LND + + for r in self.resblocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False) + else: + x = r(x, attn_mask=attn_mask) + + if not self.batch_first: + x = x.transpose(0, 1) # LND -> NLD + return x + + +def _expand_token(token, batch_size: int): + return token.view(1, 1, -1).expand(batch_size, -1, -1) + + +class VisionTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + ls_init_value: float = None, + attentional_pool: bool = False, + attn_pooler_queries: int = 256, + attn_pooler_heads: int = 8, + output_dim: int = 512, + patch_dropout: float = 0., + no_ln_pre: bool = False, + pos_embed_type: str = 'learnable', + pool_type: str = 'tok', + final_ln_after_pool: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_tokens: bool = False, + ): + super().__init__() + assert pool_type in ('tok', 'avg', 'none') + self.output_tokens = output_tokens + image_height, image_width = self.image_size = to_2tuple(image_size) + patch_height, patch_width = self.patch_size = to_2tuple(patch_size) + self.grid_size = (image_height // patch_height, image_width // patch_width) + self.final_ln_after_pool = final_ln_after_pool # currently ignored w/ attn pool enabled + self.output_dim = output_dim + + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + + # class embeddings and positional embeddings + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + if pos_embed_type == 'learnable': + self.positional_embedding = nn.Parameter( + scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) + elif pos_embed_type == 'sin_cos_2d': + # fixed sin-cos embedding + assert self.grid_size[0] == self.grid_size[1],\ + 'currently sin cos 2d pos embedding only supports square input' + self.positional_embedding = nn.Parameter( + torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False) + pos_embed_type = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True) + self.positional_embedding.data.copy_(torch.from_numpy(pos_embed_type).float()) + else: + raise ValueError + + # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn + self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. 
else nn.Identity() + + self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width) + self.transformer = Transformer( + width, + layers, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + if attentional_pool: + if isinstance(attentional_pool, str): + self.attn_pool_type = attentional_pool + self.pool_type = 'none' + if attentional_pool in ('parallel', 'cascade'): + self.attn_pool = AttentionalPooler( + output_dim, + width, + n_head=attn_pooler_heads, + n_queries=attn_pooler_queries, + ) + self.attn_pool_contrastive = AttentionalPooler( + output_dim, + width, + n_head=attn_pooler_heads, + n_queries=1, + ) + else: + assert False + else: + self.attn_pool_type = '' + self.pool_type = pool_type + self.attn_pool = AttentionalPooler( + output_dim, + width, + n_head=attn_pooler_heads, + n_queries=attn_pooler_queries, + ) + self.attn_pool_contrastive = None + pool_dim = output_dim + else: + self.attn_pool = None + pool_dim = width + self.pool_type = pool_type + + self.ln_post = norm_layer(pool_dim) + self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim)) + + self.init_parameters() + + def lock(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False): + for param in self.parameters(): + param.requires_grad = False + + if unlocked_groups != 0: + groups = [ + [ + self.conv1, + self.class_embedding, + self.positional_embedding, + self.ln_pre, + ], + *self.transformer.resblocks[:-1], + [ + self.transformer.resblocks[-1], + self.ln_post, + ], + self.proj, + ] + + def _unlock(x): + if isinstance(x, Sequence): + for g in x: + _unlock(g) + else: + if isinstance(x, torch.nn.Parameter): + x.requires_grad = True + else: + for p in x.parameters(): + p.requires_grad = True + + _unlock(groups[-unlocked_groups:]) + + def init_parameters(self): + # FIXME OpenAI CLIP did not define an init for the VisualTransformer + # TODO experiment if default PyTorch init, below, or alternate init is best. 
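+        # (The commented-out block below is the OpenAI-style init kept here for reference;
+        #  the method currently falls through to PyTorch's default initialization via `pass`.)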
+ + # nn.init.normal_(self.class_embedding, std=self.scale) + # nn.init.normal_(self.positional_embedding, std=self.scale) + # + # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + # attn_std = self.transformer.width ** -0.5 + # fc_std = (2 * self.transformer.width) ** -0.5 + # for block in self.transformer.resblocks: + # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + # + # if self.text_projection is not None: + # nn.init.normal_(self.text_projection, std=self.scale) + pass + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True): + self.transformer.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default + no_wd = {'positional_embedding', 'class_embedding'} + return no_wd + + def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.pool_type == 'avg': + pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:] + elif self.pool_type == 'tok': + pooled, tokens = x[:, 0], x[:, 1:] + else: + pooled = tokens = x + + return pooled, tokens + + def _embeds(self, x:torch.Tensor) -> torch.Tensor: + x = self.conv1(x) # shape = [*, dim, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) + # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + + # patch dropout (if active) + x = self.patch_dropout(x) + + # apply norm before transformer + x = self.ln_pre(x) + return x + + def _pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.attn_pool is not None: + if self.attn_pool_contrastive is not None: + # This is untested, WIP pooling that should match paper + x = self.ln_post(x) # TBD LN first or separate one after each pool? + tokens = self.attn_pool(x) + if self.attn_pool_type == 'parallel': + pooled = self.attn_pool_contrastive(x) + else: + assert self.attn_pool_type == 'cascade' + pooled = self.attn_pool_contrastive(tokens) + else: + # this is the original OpenCLIP CoCa setup, does not match paper + x = self.attn_pool(x) + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + elif self.final_ln_after_pool: + pooled, tokens = self._global_pool(x) + pooled = self.ln_post(pooled) + else: + x = self.ln_post(x) + pooled, tokens = self._global_pool(x) + + return pooled, tokens + + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + stop_early: bool = False, + normalize_intermediates: bool = False, + intermediates_only: bool = False, + output_fmt: str = 'NCHW', + output_extra_tokens: bool = False, + ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. 
+ + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + stop_early: Stop iterating over blocks when last desired intermediate hit + intermediates_only: Only return intermediate features + normalize_intermediates: Apply final norm layer to all intermediates + output_fmt: Shape of intermediate feature outputs + output_extra_tokens: Return both extra prefix class tokens + Returns: + + """ + assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' + reshape = output_fmt == 'NCHW' + + # forward pass + B, _, height, width = x.shape + x = self._embeds(x) + x, intermediates = self.transformer.forward_intermediates( + x, + indices=indices, + stop_early=stop_early, + ) + + # process intermediates + if normalize_intermediates: + # apply final norm to all intermediates + intermediates = [self.ln_post(xi) for xi in intermediates] + num_prefix_tokens = 1 # one class token that's always there (as of now) + if num_prefix_tokens: + # split prefix (e.g. class, distill) and spatial feature tokens + prefix_tokens = [y[:, 0:num_prefix_tokens] for y in intermediates] + intermediates = [y[:, num_prefix_tokens:] for y in intermediates] + else: + prefix_tokens = None + if reshape: + # reshape to BCHW output format + H, W = height // self.patch_size[0], width // self.patch_size[1] + intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] + + output = {'image_intermediates': intermediates} + if prefix_tokens is not None and output_extra_tokens: + output['image_intermediates_prefix'] = prefix_tokens + + if intermediates_only: + return output + + pooled, _ = self._pool(x) + + if self.proj is not None: + pooled = pooled @ self.proj + + output['image_features'] = pooled + + return output + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. 
+ """ + take_indices = self.transformer.prune_intermediate_layers(indices) + if prune_norm: + self.ln_post = nn.Identity() + if prune_head: + self.proj = None + return take_indices + + def forward(self, x: torch.Tensor): + x = self._embeds(x) + x = self.transformer(x) + pooled, tokens = self._pool(x) + + if self.proj is not None: + pooled = pooled @ self.proj + + if self.output_tokens: + return pooled, tokens + + return pooled + + +def text_global_pool( + x: torch.Tensor, + text: Optional[torch.Tensor] = None, + pool_type: str = 'argmax', +) -> torch.Tensor: + if pool_type == 'first': + pooled = x[:, 0] + elif pool_type == 'last': + pooled = x[:, -1] + elif pool_type == 'argmax': + # take features from the eot embedding (eot_token is the highest number in each sequence) + assert text is not None + pooled = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] + else: + pooled = x + + return pooled + + +class TextTransformer(nn.Module): + output_tokens: torch.jit.Final[bool] + + def __init__( + self, + context_length: int = 77, + vocab_size: int = 49408, + width: int = 512, + heads: int = 8, + layers: int = 12, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + output_dim: Optional[int] = 512, + embed_cls: bool = False, + no_causal_mask: bool = False, + pad_id: int = 0, + pool_type: str = 'argmax', + proj_type: str = 'linear', + proj_bias: bool = False, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_tokens: bool = False, + ): + super().__init__() + assert pool_type in ('first', 'last', 'argmax', 'none') + self.output_tokens = output_tokens + self.num_pos = self.context_length = context_length + self.vocab_size = vocab_size + self.width = width + self.output_dim = output_dim + self.heads = heads + self.pad_id = pad_id + self.pool_type = pool_type + + self.token_embedding = nn.Embedding(vocab_size, width) + if embed_cls: + self.cls_emb = nn.Parameter(torch.empty(width)) + self.num_pos += 1 + else: + self.cls_emb = None + self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) + self.transformer = Transformer( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + ) + self.ln_final = norm_layer(width) + + if no_causal_mask: + self.attn_mask = None + else: + self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False) + + if proj_type == 'none' or not output_dim: + self.text_projection = None + else: + if proj_bias: + self.text_projection = nn.Linear(width, output_dim) + else: + self.text_projection = nn.Parameter(torch.empty(width, output_dim)) + + self.init_parameters() + + def init_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + if self.cls_emb is not None: + nn.init.normal_(self.cls_emb, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + if isinstance(self.text_projection, nn.Linear): + nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5) + if self.text_projection.bias 
is not None: + nn.init.zeros_(self.text_projection.bias) + else: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.transformer.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default + no_wd = {'positional_embedding'} + if self.cls_emb is not None: + no_wd.add('cls_emb') + return no_wd + + def build_causal_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.num_pos, self.num_pos) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def build_cls_mask(self, text, cast_dtype: torch.dtype): + cls_mask = (text != self.pad_id).unsqueeze(1) + cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True) + additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) + additive_mask.fill_(0) + additive_mask.masked_fill_(~cls_mask, float("-inf")) + additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) + return additive_mask + + def _embeds(self, text) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + cast_dtype = self.transformer.get_cast_dtype() + seq_len = text.shape[1] + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + attn_mask = self.attn_mask + if self.cls_emb is not None: + seq_len += 1 + x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1) + cls_mask = self.build_cls_mask(text, cast_dtype) + if attn_mask is not None: + attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] + x = x + self.positional_embedding[:seq_len].to(cast_dtype) + return x, attn_mask + + def forward_intermediates( + self, + text: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + stop_early: bool = False, + normalize_intermediates: bool = False, + intermediates_only: bool = False, + output_fmt: str = 'NCHW', + output_extra_tokens: bool = False, + ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + text: Input text ids + indices: Take last n blocks if int, all if None, select matching indices if sequence + stop_early: Stop iterating over blocks when last desired intermediate hit + normalize_intermediates: Apply norm layer to all intermediates + intermediates_only: Only return intermediate features + output_fmt: Shape of intermediate feature outputs + output_extra_tokens: Return both prefix and intermediate tokens + Returns: + + """ + assert output_fmt in ('NLC',), 'Output format must be NLC.' 
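+        # Illustrative, commented-out usage (token ids from any CLIP tokenizer; the names
+        # `text_tower` and `token_ids` are placeholders, not defined in this file):
+        #   out = text_tower.forward_intermediates(token_ids, indices=4, intermediates_only=True)
+        #   # out['text_intermediates'] -> list of 4 [batch, n_ctx, width] tensors from the last 4 blocks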
+ # forward pass + x, attn_mask = self._embeds(text) + x, intermediates = self.transformer.forward_intermediates( + x, + attn_mask=attn_mask, + indices=indices, + stop_early=stop_early, + ) + + # process intermediates + if normalize_intermediates: + # apply final norm to all intermediates + intermediates = [self.ln_final(xi) for xi in intermediates] + + output = {} + + if self.cls_emb is not None: + seq_intermediates = [xi[:, :-1] for xi in intermediates] # separate concat'd class token from sequence + if output_extra_tokens: + # return suffix class tokens separately + cls_intermediates = [xi[:, -1:] for xi in intermediates] + output['text_intermediates_suffix'] = cls_intermediates + intermediates = seq_intermediates + output['text_intermediates'] = intermediates + + if intermediates_only: + return output + + if self.cls_emb is not None: + # presence of appended cls embed (CoCa) overrides pool_type, always take last token + pooled = text_global_pool(x, pool_type='last') + pooled = self.ln_final(pooled) # final LN applied after pooling in this case + else: + x = self.ln_final(x) + pooled = text_global_pool(x, text, pool_type=self.pool_type) + + if self.text_projection is not None: + if isinstance(self.text_projection, nn.Linear): + pooled = self.text_projection(pooled) + else: + pooled = pooled @ self.text_projection + + output['text_features'] = pooled + + return output + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices = self.transformer.prune_intermediate_layers(indices) + if prune_norm: + self.ln_final = nn.Identity() + if prune_head: + self.text_projection = None + return take_indices + + def forward(self, text): + x, attn_mask = self._embeds(text) + + x = self.transformer(x, attn_mask=attn_mask) + + # x.shape = [batch_size, n_ctx, transformer.width] + if self.cls_emb is not None: + # presence of appended cls embed (CoCa) overrides pool_type, always take last token + pooled = text_global_pool(x, pool_type='last') + pooled = self.ln_final(pooled) # final LN applied after pooling in this case + tokens = x[:, :-1] + else: + x = self.ln_final(x) + pooled = text_global_pool(x, text, pool_type=self.pool_type) + tokens = x + + if self.text_projection is not None: + if isinstance(self.text_projection, nn.Linear): + pooled = self.text_projection(pooled) + else: + pooled = pooled @ self.text_projection + + if self.output_tokens: + return pooled, tokens + + return pooled + + +class MultimodalTransformer(Transformer): + def __init__( + self, + width: int, + layers: int, + heads: int, + context_length: int = 77, + mlp_ratio: float = 4.0, + ls_init_value: float = None, + act_layer: Callable = nn.GELU, + norm_layer: Callable = LayerNorm, + output_dim: int = 512, + batch_first: bool = True, + ): + super().__init__( + width=width, + layers=layers, + heads=heads, + mlp_ratio=mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + batch_first=batch_first, + ) + self.context_length = context_length + self.cross_attn = nn.ModuleList([ + ResidualAttentionBlock( + width, + heads, + mlp_ratio, + ls_init_value=ls_init_value, + act_layer=act_layer, + norm_layer=norm_layer, + is_cross_attention=True, + batch_first=batch_first, + ) + for _ in range(layers) + ]) + + self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) + + self.ln_final = norm_layer(width) + self.text_projection = 
nn.Parameter(torch.empty(width, output_dim)) + + def init_parameters(self): + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + for block in self.transformer.cross_attn: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def forward_intermediates( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + indices: Optional[Union[int, List[int]]] = None, + stop_early: bool = False, + ): + assert False, "Not currently implemented for MultimodalTransformer w/ xattn" + + def forward(self, image_embs, text_embs): + seq_len = text_embs.shape[1] + if not self.batch_first: + image_embs = image_embs.permute(1, 0, 2) # NLD -> LND + text_embs = text_embs.permute(1, 0, 2) # NLD -> LND + + for resblock, cross_attn in zip(self.resblocks, self.cross_attn): + if self.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + text_embs = checkpoint( + resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len], use_reentrant=False) + text_embs = checkpoint( + cross_attn, text_embs, image_embs, image_embs, None, use_reentrant=False) + else: + text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) + text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) + + if not self.batch_first: + text_embs = text_embs.permute(1, 0, 2) # LND -> NLD + + out = self.ln_final(text_embs) + if self.text_projection is not None: + out = out @ self.text_projection + + return out + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + + +@dataclass +class CLIPVisionCfg: + layers: Union[Tuple[int, int, int, int], int] = 12 + width: int = 768 + head_width: int = 64 + mlp_ratio: float = 4.0 + patch_size: int = 16 + image_size: Union[Tuple[int, int], int] = 224 + + ls_init_value: Optional[float] = None # layer scale initial value + patch_dropout: float = 0. 
# what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results + attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type) + attn_pooler_queries: int = 256 # n_queries for attentional pooler + attn_pooler_heads: int = 8 # n heads for attentional_pooling + no_ln_pre: bool = False # disable pre transformer LayerNorm + pos_embed_type: str = 'learnable' + final_ln_after_pool: bool = False # apply final LayerNorm after pooling + pool_type: str = 'tok' + output_tokens: bool = False + act_kwargs: Optional[dict] = None + norm_kwargs: Optional[dict] = None + + timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size + timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model + timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') + timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') + timm_proj_bias: bool = False # enable bias final projection + timm_drop: float = 0. # head dropout + timm_drop_path: Optional[float] = None # backbone stochastic depth + + +@dataclass +class CLIPTextCfg: + context_length: int = 77 + vocab_size: int = 49408 + hf_tokenizer_name: Optional[str] = None + tokenizer_kwargs: Optional[dict] = None + + width: int = 512 + heads: int = 8 + layers: int = 12 + mlp_ratio: float = 4.0 + ls_init_value: Optional[float] = None # layer scale initial value + embed_cls: bool = False + pad_id: int = 0 + no_causal_mask: bool = False # disable causal masking + final_ln_after_pool: bool = False # apply final LayerNorm after pooling + pool_type: str = 'argmax' + proj_bias: bool = False + proj_type: str = 'linear' # control final text projection, 'none' forces no projection + output_tokens: bool = False + act_kwargs: dict = None + norm_kwargs: dict = None + + # HuggingFace specific text tower config + hf_model_name: Optional[str] = None + hf_model_pretrained: bool = True + hf_proj_type: str = 'mlp' + hf_pooler_type: str = 'mean_pooler' # attentional pooling for HF models + + +def get_cast_dtype(precision: str): + cast_dtype = None + if precision == 'bf16': + cast_dtype = torch.bfloat16 + elif precision == 'fp16': + cast_dtype = torch.float16 + return cast_dtype + + +def get_input_dtype(precision: str): + input_dtype = None + if precision in ('bf16', 'pure_bf16'): + input_dtype = torch.bfloat16 + elif precision in ('fp16', 'pure_fp16'): + input_dtype = torch.float16 + return input_dtype + + +def _build_vision_tower( + embed_dim: int, + vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None +): + if isinstance(vision_cfg, dict): + vision_cfg = CLIPVisionCfg(**vision_cfg) + + # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
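+    # The vision tower below is selected from vision_cfg as follows:
+    #   * timm_model_name set              -> TimmModel wrapping a timm backbone
+    #   * layers given as a tuple/list     -> ModifiedResNet (e.g. layers=(3, 4, 6, 3))
+    #   * layers given as an int (default) -> VisionTransformer (e.g. layers=12, width=768, patch_size=16)
+    # (the example values in parentheses are illustrative only, not defaults taken from this file)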
+ act_layer = QuickGELU if quick_gelu else nn.GELU + + if vision_cfg.timm_model_name: + visual = TimmModel( + vision_cfg.timm_model_name, + pretrained=vision_cfg.timm_model_pretrained, + pool=vision_cfg.timm_pool, + proj=vision_cfg.timm_proj, + proj_bias=vision_cfg.timm_proj_bias, + drop=vision_cfg.timm_drop, + drop_path=vision_cfg.timm_drop_path, + patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None, + embed_dim=embed_dim, + image_size=vision_cfg.image_size, + ) + elif isinstance(vision_cfg.layers, (tuple, list)): + vision_heads = vision_cfg.width * 32 // vision_cfg.head_width + visual = ModifiedResNet( + layers=vision_cfg.layers, + output_dim=embed_dim, + heads=vision_heads, + image_size=vision_cfg.image_size, + width=vision_cfg.width, + ) + else: + vision_heads = vision_cfg.width // vision_cfg.head_width + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + if vision_cfg.norm_kwargs: + norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs) + if vision_cfg.act_kwargs is not None: + act_layer = partial(act_layer, **vision_cfg.act_kwargs) + + visual = VisionTransformer( + image_size=vision_cfg.image_size, + patch_size=vision_cfg.patch_size, + width=vision_cfg.width, + layers=vision_cfg.layers, + heads=vision_heads, + mlp_ratio=vision_cfg.mlp_ratio, + ls_init_value=vision_cfg.ls_init_value, + patch_dropout=vision_cfg.patch_dropout, + attentional_pool=vision_cfg.attentional_pool, + attn_pooler_queries=vision_cfg.attn_pooler_queries, + attn_pooler_heads=vision_cfg.attn_pooler_heads, + pos_embed_type=vision_cfg.pos_embed_type, + no_ln_pre=vision_cfg.no_ln_pre, + final_ln_after_pool=vision_cfg.final_ln_after_pool, + pool_type=vision_cfg.pool_type, + output_tokens=vision_cfg.output_tokens, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return visual + + +def _build_text_tower( + embed_dim: int, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + if isinstance(text_cfg, dict): + text_cfg = CLIPTextCfg(**text_cfg) + + if text_cfg.hf_model_name: + text = HFTextEncoder( + text_cfg.hf_model_name, + output_dim=embed_dim, + proj_type=text_cfg.hf_proj_type, + pooler_type=text_cfg.hf_pooler_type, + pretrained=text_cfg.hf_model_pretrained, + output_tokens=text_cfg.output_tokens, + ) + else: + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + if text_cfg.norm_kwargs: + norm_layer = partial(norm_layer, **text_cfg.norm_kwargs) + if text_cfg.act_kwargs is not None: + act_layer = partial(act_layer, **text_cfg.act_kwargs) + + text = TextTransformer( + context_length=text_cfg.context_length, + vocab_size=text_cfg.vocab_size, + width=text_cfg.width, + heads=text_cfg.heads, + layers=text_cfg.layers, + mlp_ratio=text_cfg.mlp_ratio, + ls_init_value=text_cfg.ls_init_value, + output_dim=embed_dim, + embed_cls=text_cfg.embed_cls, + no_causal_mask=text_cfg.no_causal_mask, + pad_id=text_cfg.pad_id, + pool_type=text_cfg.pool_type, + proj_type=text_cfg.proj_type, + proj_bias=text_cfg.proj_bias, + output_tokens=text_cfg.output_tokens, + act_layer=act_layer, + norm_layer=norm_layer, + ) + return text + + +class CLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + init_logit_scale: float = np.log(1 / 0.07), + init_logit_bias: Optional[float] = None, + 
nonscalar_logit_scale: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + + text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.transformer = text.transformer + self.context_length = text.context_length + self.vocab_size = text.vocab_size + self.token_embedding = text.token_embedding + self.positional_embedding = text.positional_embedding + self.ln_final = text.ln_final + self.text_projection = text.text_projection + self.text_pool_type = text.pool_type + self.register_buffer('attn_mask', text.attn_mask, persistent=False) + + lshape = [1] if nonscalar_logit_scale else [] + self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale) + if init_logit_bias is not None: + self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias) + else: + self.logit_bias = None + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.transformer.grad_checkpointing = enable + + @torch.jit.ignore + def no_weight_decay(self): + # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default + no_wd = {'positional_embedding'} + if hasattr(self.visual, 'no_weight_decay'): + for n in self.visual.no_weight_decay(): + no_wd.add('visual.' + n) + return no_wd + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = self.transformer(x, attn_mask=self.attn_mask) + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + x = text_global_pool(x, text, self.text_pool_type) + if self.text_projection is not None: + if isinstance(self.text_projection, nn.Linear): + x = self.text_projection(x) + else: + x = x @ self.text_projection + + return F.normalize(x, dim=-1) if normalize else x + + def get_logits(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + image_logits = self.logit_scale.exp() * image_features @ text_features.T + if self.logit_bias is not None: + image_logits += self.logit_bias + text_logits = image_logits.T + return image_logits, text_logits + + def forward_intermediates( + self, + image: Optional[torch.Tensor] = None, + text: Optional[torch.Tensor] = None, + image_indices: Optional[Union[int, List[int]]] = None, + text_indices: Optional[Union[int, List[int]]] = None, + stop_early: bool = False, + normalize: bool = True, + normalize_intermediates: bool = False, + intermediates_only: bool = False, + image_output_fmt: str = 'NCHW', + image_output_extra_tokens: bool = False, + text_output_fmt: str = 'NLC', + text_output_extra_tokens: bool = False, + output_logits: bool = False, + output_logit_scale_bias: bool = False, + ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. 
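+        The returned dict's contents depend on the flags below: the image tower's
+        intermediate/final outputs (e.g. 'image_features'), plus 'text_intermediates',
+        'text_features', 'image_logits'/'text_logits', and 'logit_scale'/'logit_bias'.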
+ + Args: + image: Input image tensor + text: Input text tensor + image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence + text_indices: Take last n blocks if int, all if None, select matching indices if sequence + stop_early: Stop iterating over blocks when last desired intermediate hit + normalize_intermediates: Apply final norm layer to all intermediates + normalize: L2 Normalize final features + intermediates_only: Only return intermediate features, do not return final features + image_output_fmt: Shape of intermediate image feature outputs + image_output_extra_tokens: Return both prefix and spatial intermediate tokens + text_output_fmt: Shape of intermediate text feature outputs (ignored for this model) + text_output_extra_tokens: Return both prefix and spatial intermediate tokens (ignored for this model) + output_logits: Include logits in output + output_logit_scale_bias: Include the logit scale bias in the output + Returns: + + """ + output = {} + if intermediates_only: + # intermediates only disables final feature normalization, and include logits + normalize = False + output_logits = False + if output_logits: + assert image is not None and text is not None, 'Both image and text inputs are required to compute logits' + + if image is not None: + image_output = self.visual.forward_intermediates( + image, + indices=image_indices, + stop_early=stop_early, + normalize_intermediates=normalize_intermediates, + intermediates_only=intermediates_only, + output_fmt=image_output_fmt, + output_extra_tokens=image_output_extra_tokens, + ) + if normalize and "image_features" in image_output: + image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1) + output.update(image_output) + + if text is not None: + cast_dtype = self.transformer.get_cast_dtype() + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + x = x + self.positional_embedding.to(cast_dtype) + x, intermediates = self.transformer.forward_intermediates( + x, + attn_mask=self.attn_mask, + indices=text_indices + ) + if normalize_intermediates: + intermediates = [self.ln_final(xi) for xi in intermediates] + + # NOTE this model doesn't support cls embed in text transformer, no need for extra intermediate tokens + output["text_intermediates"] = intermediates + + if not intermediates_only: + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + x = text_global_pool(x, text, self.text_pool_type) + if self.text_projection is not None: + if isinstance(self.text_projection, nn.Linear): + x = self.text_projection(x) + else: + x = x @ self.text_projection + if normalize: + x = F.normalize(x, dim=-1) + output["text_features"] = x + + logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None + + if output_logits: + image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T + if self.logit_bias is not None: + image_logits += self.logit_bias + text_logits = image_logits.T + output["image_logits"] = image_logits + output["text_logits"] = text_logits + + if output_logit_scale_bias: + output["logit_scale"] = logit_scale_exp + if self.logit_bias is not None: + output['logit_bias'] = self.logit_bias + + return output + + def forward( + self, + image: Optional[torch.Tensor] = None, + text: Optional[torch.Tensor] = None, + ): + image_features = self.encode_image(image, normalize=True) if image is not None else None + text_features = self.encode_text(text, normalize=True) if text is 
not None else None + + if self.output_dict: + out_dict = { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + if self.logit_bias is not None: + out_dict['logit_bias'] = self.logit_bias + return out_dict + + if self.logit_bias is not None: + return image_features, text_features, self.logit_scale.exp(), self.logit_bias + return image_features, text_features, self.logit_scale.exp() + + +class CustomTextCLIP(nn.Module): + output_dict: torch.jit.Final[bool] + + def __init__( + self, + embed_dim: int, + vision_cfg: CLIPVisionCfg, + text_cfg: CLIPTextCfg, + quick_gelu: bool = False, + init_logit_scale: float = np.log(1 / 0.07), + init_logit_bias: Optional[float] = None, + nonscalar_logit_scale: bool = False, + cast_dtype: Optional[torch.dtype] = None, + output_dict: bool = False, + ): + super().__init__() + self.output_dict = output_dict + self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) + self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) + self.context_length = self.text.context_length + self.vocab_size = self.text.vocab_size + + lshape = [1] if nonscalar_logit_scale else [] + self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale) + if init_logit_bias is not None: + self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias) + else: + self.logit_bias = None + + def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) + + def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): + self.text.lock(unlocked_layers, freeze_layer_norm) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + + @torch.jit.ignore + def no_weight_decay(self): + # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default + no_wd = set() + if hasattr(self.visual, 'no_weight_decay'): + for n in self.visual.no_weight_decay(): + no_wd.add('visual.' + n) + if hasattr(self.text, 'no_weight_decay'): + for n in self.visual.no_weight_decay(): + no_wd.add('text.' 
+ n) + return no_wd + + def encode_image(self, image, normalize: bool = False): + features = self.visual(image) + return F.normalize(features, dim=-1) if normalize else features + + def encode_text(self, text, normalize: bool = False): + features = self.text(text) + return F.normalize(features, dim=-1) if normalize else features + + def get_logits(self, image, text): + image_features = self.encode_image(image, normalize=True) + text_features = self.encode_text(text, normalize=True) + image_logits = self.logit_scale.exp() * image_features @ text_features.T + if self.logit_bias is not None: + image_logits += self.logit_bias + text_logits = image_logits.T + return image_logits, text_logits + + def forward_intermediates( + self, + image: Optional[torch.Tensor] = None, + text: Optional[torch.Tensor] = None, + image_indices: Optional[Union[int, List[int]]] = None, + text_indices: Optional[Union[int, List[int]]] = None, + stop_early: bool = False, + normalize: bool = True, + normalize_intermediates: bool = False, + intermediates_only: bool = False, + image_output_fmt: str = 'NCHW', + image_output_extra_tokens: bool = False, + text_output_fmt: str = 'NLC', + text_output_extra_tokens: bool = False, + output_logits: bool = False, + output_logit_scale_bias: bool = False, + ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + image: Input image tensor + text: Input text tensor + image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence + text_indices: Take last n blocks if int, all if None, select matching indices if sequence + stop_early: Stop iterating over blocks when last desired intermediate hit + normalize: L2 Normalize final image and text features (if present) + normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible) + intermediates_only: Only return intermediate features, do not return final features + image_output_fmt: Shape of intermediate image feature outputs + image_output_extra_tokens: Return both prefix and spatial intermediate tokens + text_output_fmt: Shape of intermediate text feature outputs + text_output_extra_tokens: Return both prefix and spatial intermediate tokens + output_logits: Include logits in output + output_logit_scale_bias: Include the logit scale bias in the output + Returns: + + """ + output = {} + if intermediates_only: + # intermediates only disables final feature normalization, and include logits + normalize = False + output_logits = False + if output_logits: + assert image is not None and text is not None, 'Both image and text inputs are required to compute logits' + + if image is not None: + image_output = self.visual.forward_intermediates( + image, + indices=image_indices, + stop_early=stop_early, + normalize_intermediates=normalize_intermediates, + intermediates_only=intermediates_only, + output_fmt=image_output_fmt, + output_extra_tokens=image_output_extra_tokens, + ) + if normalize and "image_features" in image_output: + image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1) + output.update(image_output) + + if text is not None: + text_output = self.text.forward_intermediates( + text, + indices=text_indices, + stop_early=stop_early, + normalize_intermediates=normalize_intermediates, + intermediates_only=intermediates_only, + output_fmt=text_output_fmt, + output_extra_tokens=text_output_extra_tokens, + ) + if normalize and "text_features" in text_output: + 
text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1) + output.update(text_output) + + logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None + + if output_logits: + image_logits = logit_scale_exp * output["image_features"] @ output["text_features"].T + if self.logit_bias is not None: + image_logits += self.logit_bias + text_logits = image_logits.T + output["image_logits"] = image_logits + output["text_logits"] = text_logits + + if output_logit_scale_bias: + output["logit_scale"] = logit_scale_exp + if self.logit_bias is not None: + output['logit_bias'] = self.logit_bias + + return output + + def forward( + self, + image: Optional[torch.Tensor] = None, + text: Optional[torch.Tensor] = None, + ): + image_features = self.encode_image(image, normalize=True) if image is not None else None + text_features = self.encode_text(text, normalize=True) if text is not None else None + + if self.output_dict: + out_dict = { + "image_features": image_features, + "text_features": text_features, + "logit_scale": self.logit_scale.exp() + } + if self.logit_bias is not None: + out_dict['logit_bias'] = self.logit_bias + return out_dict + + if self.logit_bias is not None: + return image_features, text_features, self.logit_scale.exp(), self.logit_bias + return image_features, text_features, self.logit_scale.exp() + + +def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): + """Convert applicable model parameters to low-precision (bf16 or fp16)""" + + def _convert_weights(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.to(dtype) + if l.bias is not None: + l.bias.data = l.bias.data.to(dtype) + + if isinstance(l, (nn.MultiheadAttention, Attention)): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.to(dtype) + + if isinstance(l, (CLIP, TextTransformer)): + # convert text nn.Parameter projections + attr = getattr(l, "text_projection", None) + if attr is not None: + attr.data = attr.data.to(dtype) + + if isinstance(l, VisionTransformer): + # convert vision nn.Parameter projections + attr = getattr(l, "proj", None) + if attr is not None: + attr.data = attr.data.to(dtype) + + model.apply(_convert_weights) + + +convert_weights_to_fp16 = convert_weights_to_lp # backwards compat + + +# used to maintain checkpoint compatibility +def convert_to_custom_text_state_dict(state_dict: dict): + if 'text_projection' in state_dict: + # old format state_dict, move text tower -> .text + new_state_dict = {} + for k, v in state_dict.items(): + if any(k.startswith(p) for p in ( + 'text_projection', + 'positional_embedding', + 'token_embedding', + 'transformer', + 'ln_final', + )): + k = 'text.' 
+ k + new_state_dict[k] = v + return new_state_dict + return state_dict + + +def build_model_from_openai_state_dict( + state_dict: dict, + quick_gelu=True, + cast_dtype=torch.float16, +): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_size = vision_patch_size * grid_size + else: + counts: list = [ + len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_size = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + vision_cfg = CLIPVisionCfg( + layers=vision_layers, + width=vision_width, + patch_size=vision_patch_size, + image_size=image_size, + ) + text_cfg = CLIPTextCfg( + context_length=context_length, + vocab_size=vocab_size, + width=transformer_width, + heads=transformer_heads, + layers=transformer_layers, + ) + model = CLIP( + embed_dim, + vision_cfg=vision_cfg, + text_cfg=text_cfg, + quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU + cast_dtype=cast_dtype, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 + model.load_state_dict(state_dict) + return model.eval() + + +def trace_model(model, batch_size=256, device=torch.device('cpu')): + model.eval() + image_size = model.visual.image_size + example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) + example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) + model = torch.jit.trace_module( + model, + inputs=dict( + forward=(example_images, example_text), + encode_text=(example_text,), + encode_image=(example_images,) + )) + model.visual.image_size = image_size + return model + + +def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): + # Rescale the grid of position embeddings when loading from state_dict + old_pos_embed = state_dict.get('visual.positional_embedding', None) + if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): + return + grid_size = to_2tuple(model.visual.grid_size) + extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) + new_seq_len = grid_size[0] * grid_size[1] + extra_tokens + if new_seq_len == old_pos_embed.shape[0]: + return + + if extra_tokens: + pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] + else: + pos_emb_tok, pos_emb_img = None, old_pos_embed + old_grid_size = 
to_2tuple(int(math.sqrt(len(pos_emb_img)))) + + logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) + pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) + pos_emb_img = F.interpolate( + pos_emb_img, + size=grid_size, + mode=interpolation, + antialias=antialias, + align_corners=False, + ) + pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] + if pos_emb_tok is not None: + new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) + else: + new_pos_embed = pos_emb_img + state_dict['visual.positional_embedding'] = new_pos_embed + + +def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False): + old_pos_embed = state_dict.get('positional_embedding', None) + if old_pos_embed is None: + return + # FIXME add support for text cls_token + model_pos_embed = getattr(model, 'positional_embedding', None) + if model_pos_embed is None: + model_pos_embed = getattr(model.text, 'positional_embedding', None) + + old_num_pos = old_pos_embed.shape[0] + old_width = old_pos_embed.shape[1] + num_pos = model_pos_embed.shape[0] + width = model_pos_embed.shape[1] + assert old_width == width, 'text pos_embed width changed!' + if old_num_pos == num_pos: + return + + logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos) + old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1) + old_pos_embed = F.interpolate( + old_pos_embed, + size=num_pos, + mode=interpolation, + antialias=antialias, + align_corners=False, + ) + old_pos_embed = old_pos_embed.permute(0, 2, 1)[0] + new_pos_embed = old_pos_embed + + state_dict['positional_embedding'] = new_pos_embed + + +def get_model_preprocess_cfg(model): + module = getattr(model, 'visual', model) + preprocess_cfg = getattr(module, 'preprocess_cfg', {}) + if not preprocess_cfg: + # use separate legacy attributes if preprocess_cfg dict not found + size = getattr(module, 'image_size') + if size is not None: + preprocess_cfg['size'] = size + mean = getattr(module, 'image_mean', None) + if mean is not None: + preprocess_cfg['mean'] = mean + std = getattr(module, 'image_std', None) + if std is not None: + preprocess_cfg['std'] = std + return preprocess_cfg + + +def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]): + module = getattr(model, 'visual', model) + module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat + module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat + module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict + + +def get_model_tokenize_cfg(model): + module = getattr(model, 'text', model) + cfg = {} + context_length = getattr(module, 'context_length', None) + if context_length is not None: + cfg['context_length'] = context_length + vocab_size = getattr(module, 'vocab_size', None) + if vocab_size is not None: + cfg['vocab_size'] = vocab_size + return cfg + + + +try: + from huggingface_hub import hf_hub_download + hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__) + _has_hf_hub = True +except ImportError: + hf_hub_download = None + _has_hf_hub = False + + +def _pcfg(url='', hf_hub='', **kwargs): + # OpenAI / OpenCLIP defaults + return { + 'url': url, + 'hf_hub': hf_hub, + 'mean': OPENAI_DATASET_MEAN, + 'std': OPENAI_DATASET_STD, + 'interpolation': 'bicubic', + 'resize_mode': 'shortest', 
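+        # NOTE: OPENAI_DATASET_MEAN/STD are the standard CLIP normalization constants,
+        # approximately (0.4815, 0.4578, 0.4082) and (0.2686, 0.2613, 0.2758).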
+ **kwargs, + } + + +def _slpcfg(url='', hf_hub='', **kwargs): + # SiGLIP defaults + return { + 'url': url, + 'hf_hub': hf_hub, + 'mean': INCEPTION_MEAN, + 'std': INCEPTION_STD, + 'interpolation': 'bicubic', + 'resize_mode': 'squash', + **kwargs, + } + + +def _apcfg(url='', hf_hub='', **kwargs): + # CLIPA defaults + return { + 'url': url, + 'hf_hub': hf_hub, + 'mean': IMAGENET_MEAN, + 'std': IMAGENET_STD, + 'interpolation': 'bilinear', + 'resize_mode': 'squash', + **kwargs, + } + + +def _mccfg(url='', hf_hub='', **kwargs): + # MobileCLIP + return { + 'url': url, + 'hf_hub': hf_hub, + 'mean': (0., 0., 0.), + 'std': (1., 1., 1.), + 'interpolation': 'bilinear', + 'resize_mode': 'shortest', + **kwargs, + } + + + +_RN50 = dict( + openai=_pcfg( + url="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + hf_hub="timm/resnet50_clip.openai/", + quick_gelu=True, + ), + yfcc15m=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt", + hf_hub="timm/resnet50_clip.yfcc15m/", + quick_gelu=True, + ), + cc12m=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt", + hf_hub="timm/resnet50_clip.cc12m/", + quick_gelu=True, + ), +) + +_RN101 = dict( + openai=_pcfg( + url="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + hf_hub="timm/resnet101_clip.openai/", + quick_gelu=True, + ), + yfcc15m=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt", + hf_hub="timm/resnet101_clip.yfcc15m/", + quick_gelu=True, + ), +) + +_RN50x4 = dict( + openai=_pcfg( + url="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + hf_hub="timm/resnet50x4_clip.openai/", + quick_gelu=True, + ), +) + +_RN50x16 = dict( + openai=_pcfg( + url="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + hf_hub="timm/resnet50x16_clip.openai/", + quick_gelu=True, + ), +) + +_RN50x64 = dict( + openai=_pcfg( + url="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + hf_hub="timm/resnet50x64_clip.openai/", + quick_gelu=True, + ), +) + +_VITB32 = dict( + openai=_pcfg( + url="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + hf_hub="timm/vit_base_patch32_clip_224.openai/", + quick_gelu=True, + ), + # LAION 400M (quick gelu) + laion400m_e31=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", + hf_hub="timm/vit_base_patch32_clip_224.laion400m_e31/", + quick_gelu=True, + ), + laion400m_e32=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", + hf_hub="timm/vit_base_patch32_clip_224.laion400m_e32/", + quick_gelu=True, + ), + # LAION 2B-en + laion2b_e16=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth", + hf_hub="timm/vit_base_patch32_clip_224.laion2b_e16/", + ), + laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'), + # DataComp-XL models + 
datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/'), + # DataComp-M models + datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'), + commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'), + commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'), + commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'), + commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'), + commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'), + commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'), + # DataComp-S models + datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'), + commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'), + commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'), + commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'), + commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'), + commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'), + commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'), + # MetaClip models (NOTE quick-gelu activation used) + metaclip_400m=_pcfg( + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt", + hf_hub="timm/vit_base_patch32_clip_224.metaclip_400m/", + quick_gelu=True, + ), + metaclip_fullcc=_pcfg( + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt", + hf_hub="timm/vit_base_patch32_clip_224.metaclip_2pt5b/", + quick_gelu=True, + ), +) + +_VITB32_256 = dict( + datacomp_s34b_b86k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K/'), +) + +_VITB16 = dict( + openai=_pcfg( + url="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + hf_hub="timm/vit_base_patch16_clip_224.openai/", + quick_gelu=True, + ), + # LAION-400M + laion400m_e31=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt", + hf_hub="timm/vit_base_patch16_clip_224.laion400m_e31/", + ), + laion400m_e32=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt", + hf_hub="timm/vit_base_patch16_clip_224.laion400m_e32/", + ), + # LAION-2B + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), + # DataComp-XL models + datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K/'), + # DataComp-L models + datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'), + commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'), + commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'), + commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'), + commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'), + commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'), + commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'), + # DFN + dfn2b=_pcfg( + hf_hub='apple/DFN2B-CLIP-ViT-B-16/', + quick_gelu=True, + ), + # MetaCLIP 
(these are quick-gelu) + metaclip_400m=_pcfg( + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt", + hf_hub="timm/vit_base_patch16_clip_224.metaclip_400m/", + quick_gelu=True, + ), + metaclip_fullcc=_pcfg( + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt", + hf_hub="timm/vit_base_patch16_clip_224.metaclip_2pt5b/", + quick_gelu=True, + ), +) + +_VITB16_PLUS_240 = dict( + laion400m_e31=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt", + hf_hub="timm/vit_base_patch16_plus_clip_240.laion400m_e31/", + ), + laion400m_e32=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt", + hf_hub="timm/vit_base_patch16_plus_clip_240.laion400m_e31/", + ), +) + +_VITL14 = dict( + openai=_pcfg( + url="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + hf_hub="timm/vit_large_patch14_clip_224.openai/", + quick_gelu=True, + ), + # LAION-400M + laion400m_e31=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt", + hf_hub="timm/vit_large_patch14_clip_224.laion400m_e31/", + ), + laion400m_e32=_pcfg( + url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt", + hf_hub="timm/vit_large_patch14_clip_224.laion400m_e32/", + ), + # LAION-2B-en + laion2b_s32b_b82k=_pcfg( + hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', + mean=INCEPTION_MEAN, std=INCEPTION_STD), + # DataComp-XL models + datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'), + commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'), + commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'), + commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'), + # MetaCLIP + metaclip_400m=_pcfg( + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt", + hf_hub="timm/vit_large_patch14_clip_224.metaclip_400m/", + quick_gelu=True, + ), + metaclip_fullcc=_pcfg( + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt", + hf_hub="timm/vit_large_patch14_clip_224.metaclip_2pt5b/", + quick_gelu=True, + ), + # DFN-2B (quick-gelu) + dfn2b=_pcfg( + hf_hub='apple/DFN2B-CLIP-ViT-L-14/', + quick_gelu=True, + ), + # DFN-2B 39B SS + dfn2b_s39b=_pcfg( + hf_hub='apple/DFN2B-CLIP-ViT-L-14-39B/', + ), +) + +_VITL14_336 = dict( + openai=_pcfg( + url="https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", + hf_hub="timm/vit_large_patch14_clip_336.openai/", + quick_gelu=True, + ), +) + +_VITH14 = dict( + # LAION-2B-en + laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), + # MetaCLIP (quick-gelu) + metaclip_fullcc=_pcfg( + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt", + hf_hub="timm/vit_huge_patch14_clip_224.metaclip_2pt5b/", + quick_gelu=True, + ), + metaclip_altogether=_pcfg( + url="https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_v1.2_altogether.pt", + hf_hub="timm/vit_huge_patch14_clip_224.metaclip_altogether/", + # NOTE unlike other MetaCLIP models, this is not using QuickGELU, yay! 
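+        # Since quick_gelu is not set here, this tag is excluded from the
+        # '-quickgelu' alias expansion applied to _PRETRAINED further below.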
+ ), + # DFN-5B (quick-gelu) + dfn5b=_pcfg( + hf_hub='apple/DFN5B-CLIP-ViT-H-14/', + quick_gelu=True, + interpolation="bicubic", + resize_mode="squash" + ), +) + +_VITH14_378 = dict( + # DFN-5B (quick-gelu) + dfn5b=_pcfg( + hf_hub='apple/DFN5B-CLIP-ViT-H-14-378/', + quick_gelu=True, + interpolation="bicubic", + resize_mode="squash" + ), +) + +_VITg14 = dict( + laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), + laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), +) + +_VITbigG14 = dict( + # LAION-2B-en + laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), + # MetaCLIP (quick-gelu) + metaclip_fullcc=_pcfg( + url='https://dl.fbaipublicfiles.com/MMPT/metaclip/G14_fullcc2.5b.pt', + hf_hub="timm/vit_gigantic_patch14_clip_224.metaclip_2pt5b/", + quick_gelu=True, + ), +) + +_robertaViTB32 = dict( + laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'), +) + +_xlmRobertaBaseViTB32 = dict( + laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'), +) + +_xlmRobertaLargeFrozenViTH14 = dict( + frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'), +) + +_convnext_base = dict( + laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'), +) + +_convnext_base_w = dict( + laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'), + laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'), + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'), +) + +_convnext_base_w_320 = dict( + laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'), + laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), +) + +_convnext_large_d = dict( + laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'), +) + +_convnext_large_d_320 = dict( + laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'), + laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'), +) + +_convnext_xxlarge = dict( + laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'), + laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'), + laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'), +) + +_coca_VITB32 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/') +) + +_coca_VITL14 = dict( + laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'), + mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/') +) + + +_PRETRAINED = { + "RN50": _RN50, + "RN101": _RN101, + "RN50x4": _RN50x4, + "RN50x16": _RN50x16, + "RN50x64": _RN50x64, + + "ViT-B-32": _VITB32, + "ViT-B-32-256": _VITB32_256, + "ViT-B-16": _VITB16, + "ViT-B-16-plus-240": _VITB16_PLUS_240, + "ViT-L-14": _VITL14, + "ViT-L-14-336": _VITL14_336, + "ViT-H-14": _VITH14, + "ViT-H-14-378": _VITH14_378, + "ViT-g-14": _VITg14, + "ViT-bigG-14": _VITbigG14, + + "roberta-ViT-B-32": _robertaViTB32, + "xlm-roberta-base-ViT-B-32": 
_xlmRobertaBaseViTB32, + "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14, + + "convnext_base": _convnext_base, + "convnext_base_w": _convnext_base_w, + "convnext_base_w_320": _convnext_base_w_320, + "convnext_large_d": _convnext_large_d, + "convnext_large_d_320": _convnext_large_d_320, + "convnext_xxlarge": _convnext_xxlarge, + + "coca_ViT-B-32": _coca_VITB32, + "coca_ViT-L-14": _coca_VITL14, + + "EVA01-g-14": dict( + # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt + laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'), + ), + "EVA01-g-14-plus": dict( + # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt + merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'), + ), + "EVA02-B-16": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt + merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'), + ), + "EVA02-L-14": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt + merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'), + ), + "EVA02-L-14-336": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt + merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'), + ), + "EVA02-E-14": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt + laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'), + ), + "EVA02-E-14-plus": dict( + # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt + laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'), + ), + + "ViT-B-16-SigLIP": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP/'), + ), + "ViT-B-16-SigLIP-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-256/'), + ), + "ViT-B-16-SigLIP-i18n-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-i18n-256/'), + ), + "ViT-B-16-SigLIP-384": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-384/'), + ), + "ViT-B-16-SigLIP-512": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-512/'), + ), + "ViT-L-16-SigLIP-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-256/'), + ), + "ViT-L-16-SigLIP-384": dict( + webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-384/'), + ), + "ViT-SO400M-14-SigLIP": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'), + ), + "ViT-SO400M-16-SigLIP-i18n-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP-i18n-256/'), + ), + "ViT-SO400M-14-SigLIP-378": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'), # NOTE using 384 weights, but diff img_size used + ), + "ViT-SO400M-14-SigLIP-384": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'), + ), + + "ViT-B-32-SigLIP2-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-32-SigLIP2-256/'), + ), + "ViT-B-16-SigLIP2": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2/'), + ), + "ViT-B-16-SigLIP2-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-256/'), + ), + "ViT-B-16-SigLIP2-384": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-384/'), + ), + "ViT-B-16-SigLIP2-512": dict( + webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-512/'), + ), + "ViT-L-16-SigLIP2-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-256/'), + ), + "ViT-L-16-SigLIP2-384": dict( + webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-384/'), + ), + "ViT-L-16-SigLIP2-512": dict( + webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-512/'), + ), + "ViT-SO400M-14-SigLIP2": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP2/'), + ), + 
"ViT-SO400M-14-SigLIP2-378": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP2-378/'), + ), + "ViT-SO400M-16-SigLIP2-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-256/'), + ), + "ViT-SO400M-16-SigLIP2-384": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-384/'), + ), + "ViT-SO400M-16-SigLIP2-512": dict( + webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-512/'), + ), + "ViT-gopt-16-SigLIP2-256": dict( + webli=_slpcfg(hf_hub='timm/ViT-gopt-16-SigLIP2-256/'), + ), + "ViT-gopt-16-SigLIP2-384": dict( + webli=_slpcfg(hf_hub='timm/ViT-gopt-16-SigLIP2-384/'), + ), + + "ViT-L-14-CLIPA": dict( + datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-datacomp1B/'), + ), + "ViT-L-14-CLIPA-336": dict( + datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-336-datacomp1B/'), + ), + "ViT-H-14-CLIPA": dict( + datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-datacomp1B/'), + ), + "ViT-H-14-CLIPA-336": dict( + laion2b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-laion2B/'), + datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-datacomp1B/'), + ), + "ViT-bigG-14-CLIPA": dict( + datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-datacomp1B/'), + ), + "ViT-bigG-14-CLIPA-336": dict( + datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-336-datacomp1B/'), + ), + + "nllb-clip-base": dict( + v1=_pcfg(hf_hub='visheratin/nllb-clip-base-oc/'), + ), + "nllb-clip-large": dict( + v1=_pcfg(hf_hub='visheratin/nllb-clip-large-oc/'), + ), + + "nllb-clip-base-siglip": dict( + v1=_slpcfg(hf_hub='visheratin/nllb-clip-base-siglip/'), + mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-base/'), + ), + "nllb-clip-large-siglip": dict( + v1=_slpcfg(hf_hub='visheratin/nllb-clip-large-siglip/'), + mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-large/'), + ), + + "MobileCLIP-S1": dict( + datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S1-OpenCLIP/')), + "MobileCLIP-S2": dict( + datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S2-OpenCLIP/')), + "MobileCLIP-B": dict( + datacompdr=_mccfg(hf_hub='apple/MobileCLIP-B-OpenCLIP/'), + datacompdr_lt=_mccfg(hf_hub='apple/MobileCLIP-B-LT-OpenCLIP/'), + ), + + "ViTamin-S": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S/pytorch_model.bin'), + ), + "ViTamin-S-LTT": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S-LTT/pytorch_model.bin'), + ), + "ViTamin-B": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B/pytorch_model.bin'), + ), + "ViTamin-B-LTT": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B-LTT/pytorch_model.bin'), + ), + "ViTamin-L": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-224px/pytorch_model.bin'), + ), + "ViTamin-L-256": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-256px/pytorch_model.bin'), + ), + "ViTamin-L-336": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-336px/pytorch_model.bin'), + ), + "ViTamin-L-384": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-384px/pytorch_model.bin'), + ), + "ViTamin-L2": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-224px/pytorch_model.bin'), + ), + "ViTamin-L2-256": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-256px/pytorch_model.bin'), + ), + "ViTamin-L2-336": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-336px/pytorch_model.bin'), + ), + "ViTamin-L2-384": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-384px/pytorch_model.bin'), + ), + "ViTamin-XL-256": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-256px/pytorch_model.bin'), + ), + "ViTamin-XL-336": dict( + 
datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-336px/pytorch_model.bin'), + ), + "ViTamin-XL-384": dict( + datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-384px/pytorch_model.bin'), + ), +} + +_PRETRAINED_quickgelu = {} +for k, v in _PRETRAINED.items(): + quick_gelu_tags = {} + for tk, tv in v.items(): + if tv.get('quick_gelu', False): + quick_gelu_tags[tk] = copy.deepcopy(tv) + if quick_gelu_tags: + _PRETRAINED_quickgelu[k + '-quickgelu'] = quick_gelu_tags +_PRETRAINED.update(_PRETRAINED_quickgelu) + +def _clean_tag(tag: str): + # normalize pretrained tags + return tag.lower().replace('-', '_') + + +def list_pretrained(as_str: bool = False): + """ returns list of pretrained models + Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True + """ + return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] + + +def list_pretrained_models_by_tag(tag: str): + """ return all models having the specified pretrain tag """ + models = [] + tag = _clean_tag(tag) + for k in _PRETRAINED.keys(): + if tag in _PRETRAINED[k]: + models.append(k) + return models + + +def list_pretrained_tags_by_model(model: str): + """ return all pretrain tags for the specified model architecture """ + tags = [] + if model in _PRETRAINED: + tags.extend(_PRETRAINED[model].keys()) + return tags + + +def is_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return False + return _clean_tag(tag) in _PRETRAINED[model] + + +def get_pretrained_cfg(model: str, tag: str): + if model not in _PRETRAINED: + return {} + model_pretrained = _PRETRAINED[model] + return model_pretrained.get(_clean_tag(tag), {}) + + +def get_pretrained_url(model: str, tag: str): + cfg = get_pretrained_cfg(model, _clean_tag(tag)) + return cfg.get('url', '') + + +def download_pretrained_from_url( + url: str, + cache_dir: Optional[str] = None, +): + if not cache_dir: + cache_dir = os.path.expanduser("~/.cache/clip") + os.makedirs(cache_dir, exist_ok=True) + filename = os.path.basename(url) + + if 'openaipublic' in url: + expected_sha256 = url.split("/")[-2] + elif 'mlfoundations' in url: + expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] + else: + expected_sha256 = '' + + download_target = os.path.join(cache_dir, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if expected_sha256: + if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + else: + return download_target + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): + raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def has_hf_hub(necessary=False): + if not _has_hf_hub and necessary: + # if no HF Hub module installed, and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model 
specified but package not installed. Run `pip install huggingface_hub`.') + return _has_hf_hub + + +def _get_safe_alternatives(filename: str) -> Iterable[str]: + """Returns potential safetensors alternatives for a given filename. + + Use case: + When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it. + """ + if filename == HF_WEIGHTS_NAME: + yield HF_SAFE_WEIGHTS_NAME + + if filename not in (HF_WEIGHTS_NAME,) and (filename.endswith(".bin") or filename.endswith(".pth")): + yield filename[:-4] + ".safetensors" + + +def download_pretrained_from_hf( + model_id: str, + filename: Optional[str] = None, + revision: Optional[str] = None, + cache_dir: Optional[str] = None, +): + has_hf_hub(True) + + filename = filename or HF_WEIGHTS_NAME + + # Look for .safetensors alternatives and load from it if it exists + if _has_safetensors: + for safe_filename in _get_safe_alternatives(filename): + try: + cached_file = hf_hub_download( + repo_id=model_id, + filename=safe_filename, + revision=revision, + cache_dir=cache_dir, + ) + return cached_file + except Exception: + pass + + try: + # Attempt to download the file + cached_file = hf_hub_download( + repo_id=model_id, + filename=filename, + revision=revision, + cache_dir=cache_dir, + ) + return cached_file # Return the path to the downloaded file if successful + except Exception as e: + raise FileNotFoundError(f"Failed to download file ({filename}) for {model_id}. Last error: {e}") + + +def download_pretrained( + cfg: Dict, + prefer_hf_hub: bool = True, + cache_dir: Optional[str] = None, +): + target = '' + if not cfg: + return target + + if 'file' in cfg: + return cfg['file'] + + has_hub = has_hf_hub() + download_url = cfg.get('url', '') + download_hf_hub = cfg.get('hf_hub', '') + if has_hub and prefer_hf_hub and download_hf_hub: + # prefer to use HF hub, remove url info + download_url = '' + + if download_url: + target = download_pretrained_from_url(download_url, cache_dir=cache_dir) + elif download_hf_hub: + has_hf_hub(True) + # we assume the hf_hub entries in pretrained config combine model_id + filename in + # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and + # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. + model_id, filename = os.path.split(download_hf_hub) + if filename: + target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) + else: + target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + + return target + +# ================================================================== +def merge_preprocess_dict( + base: Union[PreprocessCfg, Dict], + overlay: Dict, +): + """ Merge overlay key-value pairs on top of base preprocess cfg or dict. + Input dicts are filtered based on PreprocessCfg fields. + """ + if isinstance(base, PreprocessCfg): + base_clean = asdict(base) + else: + base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS} + if overlay: + overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None} + base_clean.update(overlay_clean) + return base_clean + + +def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs): + return merge_preprocess_dict(base, kwargs) + + +@dataclass +class PreprocessCfg: + size: Union[int, Tuple[int, int]] = 224 + mode: str = 'RGB' + mean: Tuple[float, ...] = OPENAI_DATASET_MEAN + std: Tuple[float, ...] 
= OPENAI_DATASET_STD + interpolation: str = 'bicubic' + resize_mode: str = 'shortest' + fill_color: int = 0 + + def __post_init__(self): + assert self.mode in ('RGB',) + + @property + def num_channels(self): + return 3 + + @property + def input_size(self): + return (self.num_channels,) + to_2tuple(self.size) + + + + +@dataclass +class PreprocessCfg: + size: Union[int, Tuple[int, int]] = 224 + mode: str = 'RGB' + mean: Tuple[float, ...] = OPENAI_DATASET_MEAN + std: Tuple[float, ...] = OPENAI_DATASET_STD + interpolation: str = 'bicubic' + resize_mode: str = 'shortest' + fill_color: int = 0 + + def __post_init__(self): + assert self.mode in ('RGB',) + + @property + def num_channels(self): + return 3 + + @property + def input_size(self): + return (self.num_channels,) + to_2tuple(self.size) + +_PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys()) + + +def merge_preprocess_dict( + base: Union[PreprocessCfg, Dict], + overlay: Dict, +): + """ Merge overlay key-value pairs on top of base preprocess cfg or dict. + Input dicts are filtered based on PreprocessCfg fields. + """ + if isinstance(base, PreprocessCfg): + base_clean = asdict(base) + else: + base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS} + if overlay: + overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None} + base_clean.update(overlay_clean) + return base_clean + + +def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs): + return merge_preprocess_dict(base, kwargs) + + +@dataclass +class AugmentationCfg: + scale: Tuple[float, float] = (0.9, 1.0) + ratio: Optional[Tuple[float, float]] = None + color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None + re_prob: Optional[float] = None + re_count: Optional[int] = None + use_timm: bool = False + + # params for simclr_jitter_gray + color_jitter_prob: float = None + gray_scale_prob: float = None + + +def _setup_size(size, error_msg): + if isinstance(size, numbers.Number): + return int(size), int(size) + + if isinstance(size, Sequence) and len(size) == 1: + return size[0], size[0] + + if len(size) != 2: + raise ValueError(error_msg) + + return size + + +class ResizeKeepRatio: + """ Resize and Keep Ratio + + Copy & paste from `timm` + """ + + def __init__( + self, + size, + longest=0., + interpolation=InterpolationMode.BICUBIC, + random_scale_prob=0., + random_scale_range=(0.85, 1.05), + random_aspect_prob=0., + random_aspect_range=(0.9, 1.11) + ): + if isinstance(size, (list, tuple)): + self.size = tuple(size) + else: + self.size = (size, size) + self.interpolation = interpolation + self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest + self.random_scale_prob = random_scale_prob + self.random_scale_range = random_scale_range + self.random_aspect_prob = random_aspect_prob + self.random_aspect_range = random_aspect_range + + @staticmethod + def get_params( + img, + target_size, + longest, + random_scale_prob=0., + random_scale_range=(0.85, 1.05), + random_aspect_prob=0., + random_aspect_range=(0.9, 1.11) + ): + """Get parameters + """ + source_size = img.size[::-1] # h, w + h, w = source_size + target_h, target_w = target_size + ratio_h = h / target_h + ratio_w = w / target_w + ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. 
- longest) + if random_scale_prob > 0 and random.random() < random_scale_prob: + ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1]) + ratio_factor = (ratio_factor, ratio_factor) + else: + ratio_factor = (1., 1.) + if random_aspect_prob > 0 and random.random() < random_aspect_prob: + aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1]) + ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor) + size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)] + return size + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. + + Returns: + PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size + """ + size = self.get_params( + img, self.size, self.longest, + self.random_scale_prob, self.random_scale_range, + self.random_aspect_prob, self.random_aspect_range + ) + img = F.resize(img, size, self.interpolation) + return img + + def __repr__(self): + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += f', interpolation={self.interpolation})' + format_string += f', longest={self.longest:.3f})' + return format_string + + +def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor: + """Center crops and/or pads the given image. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. + + Args: + img (PIL Image or Tensor): Image to be cropped. + output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int, + it is used for both directions. + fill (int, Tuple[int]): Padding color + + Returns: + PIL Image or Tensor: Cropped image. + """ + if isinstance(output_size, numbers.Number): + output_size = (int(output_size), int(output_size)) + elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: + output_size = (output_size[0], output_size[0]) + + _, image_height, image_width = F.get_dimensions(img) + crop_height, crop_width = output_size + + if crop_width > image_width or crop_height > image_height: + padding_ltrb = [ + (crop_width - image_width) // 2 if crop_width > image_width else 0, + (crop_height - image_height) // 2 if crop_height > image_height else 0, + (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, + (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, + ] + img = F.pad(img, padding_ltrb, fill=fill) + _, image_height, image_width = F.get_dimensions(img) + if crop_width == image_width and crop_height == image_height: + return img + + crop_top = int(round((image_height - crop_height) / 2.0)) + crop_left = int(round((image_width - crop_width) / 2.0)) + return F.crop(img, crop_top, crop_left, crop_height, crop_width) + + +class CenterCropOrPad(torch.nn.Module): + """Crops the given image at the center. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). 
+ """ + + def __init__(self, size, fill=0): + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + self.fill = fill + + def forward(self, img): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + PIL Image or Tensor: Cropped image. + """ + return center_crop_or_pad(img, self.size, fill=self.fill) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size})" + + +def _convert_to_rgb(image): + return image.convert('RGB') + + +class color_jitter(object): + """ + Apply Color Jitter to the PIL image with a specified probability. + """ + def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., p=0.8): + assert 0. <= p <= 1. + self.p = p + self.transf = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) + + def __call__(self, img): + if random.random() < self.p: + return self.transf(img) + else: + return img + + +class gray_scale(object): + """ + Apply Gray Scale to the PIL image with a specified probability. + """ + def __init__(self, p=0.2): + assert 0. <= p <= 1. + self.p = p + self.transf = Grayscale(num_output_channels=3) + + def __call__(self, img): + if random.random() < self.p: + return self.transf(img) + else: + return img + + +def image_transform( + image_size: Union[int, Tuple[int, int]], + is_train: bool, + mean: Optional[Tuple[float, ...]] = None, + std: Optional[Tuple[float, ...]] = None, + resize_mode: Optional[str] = None, + interpolation: Optional[str] = None, + fill_color: int = 0, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, +): + mean = mean or OPENAI_DATASET_MEAN + if not isinstance(mean, (list, tuple)): + mean = (mean,) * 3 + + std = std or OPENAI_DATASET_STD + if not isinstance(std, (list, tuple)): + std = (std,) * 3 + + interpolation = interpolation or 'bicubic' + assert interpolation in ['bicubic', 'bilinear', 'random'] + # NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for inference if set + interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC + + resize_mode = resize_mode or 'shortest' + assert resize_mode in ('shortest', 'longest', 'squash') + + if isinstance(aug_cfg, dict): + aug_cfg = AugmentationCfg(**aug_cfg) + else: + aug_cfg = aug_cfg or AugmentationCfg() + + normalize = Normalize(mean=mean, std=std) + + if is_train: + aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} + use_timm = aug_cfg_dict.pop('use_timm', False) + if use_timm: + from timm.data import create_transform # timm can still be optional + if isinstance(image_size, (tuple, list)): + assert len(image_size) >= 2 + input_size = (3,) + image_size[-2:] + else: + input_size = (3, image_size, image_size) + + aug_cfg_dict.setdefault('color_jitter', None) # disable by default + # drop extra non-timm items + aug_cfg_dict.pop('color_jitter_prob', None) + aug_cfg_dict.pop('gray_scale_prob', None) + + train_transform = create_transform( + input_size=input_size, + is_training=True, + hflip=0., + mean=mean, + std=std, + re_mode='pixel', + interpolation=interpolation, + **aug_cfg_dict, + ) + else: + train_transform = [ + RandomResizedCrop( + image_size, + scale=aug_cfg_dict.pop('scale'), + interpolation=InterpolationMode.BICUBIC, + ), + _convert_to_rgb, + ] + if aug_cfg.color_jitter_prob: + assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4 + train_transform.extend([ + color_jitter(*aug_cfg.color_jitter, 
p=aug_cfg.color_jitter_prob) + ]) + if aug_cfg.gray_scale_prob: + train_transform.extend([ + gray_scale(aug_cfg.gray_scale_prob) + ]) + train_transform.extend([ + ToTensor(), + normalize, + ]) + train_transform = Compose(train_transform) + if aug_cfg_dict: + warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') + return train_transform + else: + if resize_mode == 'longest': + transforms = [ + ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1), + CenterCropOrPad(image_size, fill=fill_color) + ] + elif resize_mode == 'squash': + if isinstance(image_size, int): + image_size = (image_size, image_size) + transforms = [ + Resize(image_size, interpolation=interpolation_mode), + ] + else: + assert resize_mode == 'shortest' + if not isinstance(image_size, (tuple, list)): + image_size = (image_size, image_size) + if image_size[0] == image_size[1]: + # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg) + transforms = [ + Resize(image_size[0], interpolation=interpolation_mode) + ] + else: + # resize shortest edge to matching target dim for non-square target + transforms = [ResizeKeepRatio(image_size)] + transforms += [CenterCrop(image_size)] + + transforms.extend([ + _convert_to_rgb, + ToTensor(), + normalize, + ]) + return Compose(transforms) + + +def image_transform_v2( + cfg: PreprocessCfg, + is_train: bool, + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, +): + return image_transform( + image_size=cfg.size, + is_train=is_train, + mean=cfg.mean, + std=cfg.std, + interpolation=cfg.interpolation, + resize_mode=cfg.resize_mode, + fill_color=cfg.fill_color, + aug_cfg=aug_cfg, + ) + +@dataclass +class AugmentationCfg: + scale: Tuple[float, float] = (0.9, 1.0) + ratio: Optional[Tuple[float, float]] = None + color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None + re_prob: Optional[float] = None + re_count: Optional[int] = None + use_timm: bool = False + + # params for simclr_jitter_gray + color_jitter_prob: float = None + gray_scale_prob: float = None + +def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]): + module = getattr(model, 'visual', model) + module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat + module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat + module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict + + +@torch.no_grad() +def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True): + + def _convert_timm_img(state_dict): + if fastvit: + from timm.models.fastvit import checkpoint_filter_fn + else: + from timm.models.vision_transformer_hybrid import checkpoint_filter_fn + timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk) + timm_state_dict = {'visual.trunk.' 
+ k: v for k, v in timm_state_dict.items()} + return timm_state_dict + + def _convert_openclip_txt(state_dict, prefix='text_encoder.'): + text_dict = {} + for k, v in state_dict.items(): + if not k.startswith(prefix): + continue + k = k.replace(prefix, '') + k = k.replace('projection_layer', 'text_projection') + k = k.replace('embedding_layer', 'token_embedding') + if k.startswith('positional_embedding.pos_embed.pos_embed'): + k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding') + v = v.squeeze() + k = k.replace('final_layer_norm', 'ln_final') + k = k.replace('pre_norm_mha.0', 'ln_1') + k = k.replace('pre_norm_mha.1', 'attn') + k = k.replace('pre_norm_ffn.0', 'ln_2') + k = k.replace('pre_norm_ffn.1', 'mlp.c_fc') + k = k.replace('pre_norm_ffn.4', 'mlp.c_proj') + k = k.replace('qkv_proj.weight', 'in_proj_weight') + k = k.replace('qkv_proj.bias', 'in_proj_bias') + k = k.replace('transformer.', 'transformer.resblocks.') + text_dict['text.' + k] = v + return text_dict + + image_dict = _convert_timm_img(state_dict) + text_dict = _convert_openclip_txt(state_dict) + out_dict = {**image_dict, **text_dict} + out_dict['logit_scale'] = state_dict['logit_scale'] + return out_dict + + +def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict): + if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict: + # Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported) + state_dict = convert_mobile_clip_state_dict(model, state_dict) + if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict: + # convert b model + state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False) + return state_dict + +def load_state_dict( + checkpoint_path: str, + device='cpu', + weights_only=True, +): + # Check if safetensors or not and load weights accordingly + if str(checkpoint_path).endswith(".safetensors"): + from safetensors.torch import load_file + checkpoint = load_file(checkpoint_path, device=device) + else: + try: + checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only) + except TypeError: + checkpoint = torch.load(checkpoint_path, map_location=device) + + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif isinstance(checkpoint, torch.jit.ScriptModule): + state_dict = checkpoint.state_dict() + for key in ["input_resolution", "context_length", "vocab_size"]: + state_dict.pop(key, None) + else: + state_dict = checkpoint + if next(iter(state_dict.items()))[0].startswith('module'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + return state_dict + +def load_checkpoint( + model: Union[CLIP, CustomTextCLIP], + checkpoint_path: str, + strict: bool = True, + weights_only: bool = True, + device='cpu', +): + if Path(checkpoint_path).suffix in ('.npz', '.npy'): + # Separate path loading numpy big_vision (SigLIP) weights + from open_clip.convert import load_big_vision_weights + load_big_vision_weights(model, checkpoint_path) + return {} + + state_dict = load_state_dict(checkpoint_path, device=device, weights_only=weights_only) + + # Detect & convert 3rd party state_dicts -> open_clip + state_dict = convert_state_dict(model, state_dict) + + # Detect old format and make compatible with new format + if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): + state_dict = convert_to_custom_text_state_dict(state_dict) + + # correct if logit_scale differs in being scaler vs 1d param + if 'logit_scale' 
in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim: + state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape) + + # correct if logit_bias differs in being scaler vs 1d param + if 'logit_bias' in state_dict and model.logit_bias.ndim != state_dict['logit_bias'].ndim: + state_dict['logit_bias'] = state_dict['logit_bias'].reshape(model.logit_bias.shape) + + # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712 + if 'logit_bias' not in state_dict and model.logit_bias is not None: + state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"]) + + # Certain text transformers no longer expect position_ids after transformers==4.31 + position_id_key = 'text.transformer.embeddings.position_ids' + if position_id_key in state_dict and not hasattr(model, position_id_key): + del state_dict[position_id_key] + + resize_pos_embed(state_dict, model) + resize_text_pos_embed(state_dict, model) + + # Finally, load the massaged state_dict into model + incompatible_keys = model.load_state_dict(state_dict, strict=strict) + return incompatible_keys + +# /home/IITB/ai-at-ieor/23m1521/.conda/envs/openclip2/lib/python3.11/site-packages/open_clip/factory.py +HF_HUB_PREFIX = 'hf-hub:' +# _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] +_MODEL_CONFIG_PATHS = [Path("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/model_configs")] +_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs + +import json + +def _get_hf_config( + model_id: str, + cache_dir: Optional[str] = None, +): + """ Fetch model config from HuggingFace Hub. + """ + config_path = download_pretrained_from_hf( + model_id, + filename='open_clip_config.json', + cache_dir=cache_dir, + ) + with open(config_path, 'r', encoding='utf-8') as f: + config = json.load(f) + return config + +def get_model_config(model_name): + """ Fetch model config from builtin (local library) configs. 
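+
+    Configs are the JSON files collected from _MODEL_CONFIG_PATHS by
+    _rescan_model_configs() and keyed by filename stem; a deep copy is returned
+    so callers may mutate it, and None is returned for unknown architectures.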
+ """ + if model_name in _MODEL_CONFIGS: + return copy.deepcopy(_MODEL_CONFIGS[model_name]) + else: + return None + +def _natural_key(string_): + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def _rescan_model_configs(): + global _MODEL_CONFIGS + + config_ext = ('.json',) + config_files = [] + for config_path in _MODEL_CONFIG_PATHS: + if config_path.is_file() and config_path.suffix in config_ext: + config_files.append(config_path) + elif config_path.is_dir(): + for ext in config_ext: + config_files.extend(config_path.glob(f'*{ext}')) + + for cf in config_files: + with open(cf, 'r') as f: + model_cfg = json.load(f) + if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): + _MODEL_CONFIGS[cf.stem] = model_cfg + + _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} + + +_rescan_model_configs() # initial populate of model config registry + +def list_models(): + """ enumerate available model architectures based on config files """ + return list(_MODEL_CONFIGS.keys()) + + +def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs): + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + else: + position_ids = None + return { + "text": input_ids, + "images": image_inputs, + "past_key_values": past, + "position_ids": position_ids, + "attention_mask": attention_mask, + } + +@dataclass +class MultimodalCfg(CLIPTextCfg): + mlp_ratio: int = 4 + dim_head: int = 64 + heads: int = 8 + n_queries: int = 256 + attn_pooler_heads: int = 8 + +try: + from transformers import ( + BeamSearchScorer, + LogitsProcessorList, + TopPLogitsWarper, + TopKLogitsWarper, + RepetitionPenaltyLogitsProcessor, + MinLengthLogitsProcessor, + MaxLengthCriteria, + StopStringCriteria, + EosTokenCriteria, + StoppingCriteriaList + ) + + GENERATION_TYPES = { + "top_k": TopKLogitsWarper, + "top_p": TopPLogitsWarper, + "beam_search": "beam_search" + } + _has_transformers = True +except ImportError as e: + GENERATION_TYPES = { + "top_k": None, + "top_p": None, + "beam_search": "beam_search" + } + _has_transformers = False + +def _token_to_tensor(token_id, device: str = "cpu") -> torch.Tensor: + if not isinstance(token_id, torch.Tensor): + if isinstance(token_id, int): + token_id = [token_id] + token_id = torch.tensor(token_id, device=device) + return token_id + + +def _build_text_decoder_tower( + embed_dim, + multimodal_cfg, + quick_gelu: bool = False, + cast_dtype: Optional[torch.dtype] = None, +): + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + act_layer = QuickGELU if quick_gelu else nn.GELU + norm_layer = ( + LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm + ) + + decoder = MultimodalTransformer( + context_length=multimodal_cfg.context_length, + width=multimodal_cfg.width, + heads=multimodal_cfg.heads, + layers=multimodal_cfg.layers, + ls_init_value=multimodal_cfg.ls_init_value, + output_dim=embed_dim, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + return decoder + +class CoCa(nn.Module): + def __init__( + self, + embed_dim, + multimodal_cfg: MultimodalCfg, + text_cfg: CLIPTextCfg, + 
vision_cfg: CLIPVisionCfg, + quick_gelu: bool = False, + init_logit_scale: float = np.log(1 / 0.07), + init_logit_bias: Optional[float] = None, + nonscalar_logit_scale: bool = False, + cast_dtype: Optional[torch.dtype] = None, + pad_id: int = 0, + ): + super().__init__() + multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg + text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg + vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg + + self.text = _build_text_tower( + embed_dim=embed_dim, + text_cfg=text_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + vocab_size = ( + text_cfg.vocab_size # for hf models + if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None + else text_cfg.vocab_size + ) + + self.visual = _build_vision_tower( + embed_dim=embed_dim, + vision_cfg=vision_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + self.text_decoder = _build_text_decoder_tower( + vocab_size, + multimodal_cfg=multimodal_cfg, + quick_gelu=quick_gelu, + cast_dtype=cast_dtype, + ) + + lshape = [1] if nonscalar_logit_scale else [] + self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale) + if init_logit_bias is not None: + self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias) + else: + self.logit_bias = None + self.pad_id = pad_id + + self.context_length = multimodal_cfg.context_length + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True): + self.visual.set_grad_checkpointing(enable) + self.text.set_grad_checkpointing(enable) + self.text_decoder.set_grad_checkpointing(enable) + + def _encode_image(self, images, normalize: bool = True): + image_latent, tokens_embs = self.visual(images) + image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent + return image_latent, tokens_embs + + def _encode_text(self, text, normalize: bool = True): + text_latent, token_emb = self.text(text) + text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent + return text_latent, token_emb + + def encode_image(self, images, normalize: bool = True): + image_latent, _ = self._encode_image(images, normalize=normalize) + return image_latent + + def encode_text(self, text, normalize: bool = True): + text_latent, _ = self._encode_text(text, normalize=normalize) + return text_latent + + def forward_intermediates( + self, + image: Optional[torch.Tensor] = None, + text: Optional[torch.Tensor] = None, + image_indices: Optional[Union[int, List[int]]] = None, + text_indices: Optional[Union[int, List[int]]] = None, + stop_early: bool = False, + normalize: bool = True, + normalize_intermediates: bool = False, + intermediates_only: bool = False, + image_output_fmt: str = 'NCHW', + image_output_extra_tokens: bool = False, + text_output_fmt: str = 'NLC', + text_output_extra_tokens: bool = False, + output_logits: bool = False, + output_logit_scale_bias: bool = False, + ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. 
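+        Image and/or text towers are queried for per-block features through their
+        own forward_intermediates implementations; the text decoder is not covered
+        yet (see the FIXME further down).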
+ + Args: + image: Input image tensor + text: Input text tensor + image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence + text_indices: Take last n blocks if int, all if None, select matching indices if sequence + stop_early: Stop iterating over blocks when last desired intermediate hit + normalize: L2 Normalize final image and text features (if present) + normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible) + intermediates_only: Only return intermediate features, do not return final features + image_output_fmt: Shape of intermediate image feature outputs + image_output_extra_tokens: Return both prefix and spatial intermediate tokens + text_output_fmt: Shape of intermediate text feature outputs + text_output_extra_tokens: Return both prefix and spatial intermediate tokens + output_logits: Include logits in output + output_logit_scale_bias: Include the logit scale bias in the output + Returns: + + """ + output = {} + if intermediates_only: + # intermediates only disables final feature normalization, and include logits + normalize = False + output_logits = False + if output_logits: + assert False, 'FIXME, needs implementing' + + if image is not None: + image_output = self.visual.forward_intermediates( + image, + indices=image_indices, + stop_early=stop_early, + normalize_intermediates=normalize_intermediates, + intermediates_only=intermediates_only, + output_fmt=image_output_fmt, + output_extra_tokens=image_output_extra_tokens, + ) + if normalize and "image_features" in image_output: + image_output["image_features"] = F.normalize(image_output["image_features"], dim=-1) + output.update(image_output) + + if text is not None: + text_output = self.text.forward_intermediates( + text, + indices=text_indices, + stop_early=stop_early, + normalize_intermediates=normalize_intermediates, + intermediates_only=intermediates_only, + output_fmt=text_output_fmt, + output_extra_tokens=text_output_extra_tokens, + ) + if normalize and "text_features" in text_output: + text_output["text_features"] = F.normalize(text_output["text_features"], dim=-1) + output.update(text_output) + + # FIXME text decoder + logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None + if output_logit_scale_bias: + output["logit_scale"] = logit_scale_exp + if self.logit_bias is not None: + output['logit_bias'] = self.logit_bias + + return output + + def forward( + self, + image, + text: Optional[torch.Tensor] = None, + image_latent: Optional[torch.Tensor] = None, + image_embs: Optional[torch.Tensor] = None, + output_labels: bool = True, + ): + if image_latent is None or image_embs is None: + image_latent, image_embs = self._encode_image(image) + + if text is None: + return {"image_features": image_latent, "image_embs": image_embs} + + text_latent, token_embs = self._encode_text(text) + + # FIXME this isn't an ideal solution, would like to improve -RW + labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None + if output_labels: + # align text_embs and thus logits with labels for teacher-forcing caption loss + token_embs = token_embs[:, :-1] + + logits = self.text_decoder(image_embs, token_embs) + out_dict = { + "image_features": image_latent, + "text_features": text_latent, + "logits": logits, + "logit_scale": self.logit_scale.exp() + } + if labels is not None: + out_dict["labels"] = labels + if self.logit_bias is not None: + out_dict["logit_bias"] = self.logit_bias + return out_dict + + def 
generate( + self, + image, + text=None, + seq_len=30, + max_seq_len=77, + temperature=1., + generation_type="beam_search", + top_p=0.1, # keep tokens in the 1 - top_p quantile + top_k=1, # keeps the top_k most probable tokens + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + repetition_penalty=1.0, + fixed_output_length=False # if True output.shape == (batch_size, seq_len) + ): + # taking many ideas and components from HuggingFace GenerationMixin + # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation + assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`." + assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len" + device = image.device + + with torch.no_grad(): + sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device) + eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device) + pad_token_id = self.pad_id if pad_token_id is None else pad_token_id + logit_processor = LogitsProcessorList( + [ + MinLengthLogitsProcessor(min_seq_len, eos_token_id), + RepetitionPenaltyLogitsProcessor(repetition_penalty), + ] + ) + + if stopping_criteria is None: + stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] + stopping_criteria = StoppingCriteriaList(stopping_criteria) + + if generation_type == "beam_search": + output = self._generate_beamsearch( + image_inputs=image, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + sot_token_id=sot_token_id, + num_beams=num_beams, + num_beam_groups=num_beam_groups, + min_seq_len=min_seq_len, + stopping_criteria=stopping_criteria, + logit_processor=logit_processor, + ) + if fixed_output_length and output.shape[1] < seq_len: + pad_len = seq_len - output.shape[1] + return torch.cat(( + output, + torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id + ), + dim=1 + ) + return output + + elif generation_type == "top_p": + logit_warper = GENERATION_TYPES[generation_type](top_p) + elif generation_type == "top_k": + logit_warper = GENERATION_TYPES[generation_type](top_k) + else: + raise ValueError( + f"generation_type has to be one of " + f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." 
+ ) + + image_latent, image_embs = self._encode_image(image) + + if text is None: + text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id + + was_training = self.training + num_dims = len(text.shape) + + if num_dims == 1: + text = text[None, :] + + self.eval() + out = text + + while True: + x = out[:, -max_seq_len:] + cur_len = x.shape[1] + logits = self( + image, + x, + image_latent=image_latent, + image_embs=image_embs, + output_labels=False, + )["logits"][:, -1] + mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id) + sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id + + if mask.all(): + if not fixed_output_length: + break + else: + logits = logits[~mask, :] + filtered_logits = logit_processor(x[~mask, :], logits) + filtered_logits = logit_warper(x[~mask, :], filtered_logits) + probs = F.softmax(filtered_logits / temperature, dim=-1) + + if (cur_len + 1 == seq_len): + sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id + else: + sample[~mask, :] = torch.multinomial(probs, 1) + + out = torch.cat((out, sample), dim=-1) + + cur_len += 1 + + if all(stopping_criteria(out, None)): + break + + if num_dims == 1: + out = out.squeeze(0) + + self.train(was_training) + return out + + def _generate_beamsearch( + self, + image_inputs, + pad_token_id=None, + eos_token_id=None, + sot_token_id=None, + num_beams=6, + num_beam_groups=3, + min_seq_len=5, + stopping_criteria=None, + logit_processor=None, + logit_warper=None, + ): + device = image_inputs.device + batch_size = image_inputs.shape[0] + image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0) + image_latent, image_embs = self._encode_image(image_inputs) + + input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long) + input_ids = input_ids * sot_token_id + beam_scorer = BeamSearchScorer( + batch_size=batch_size, + num_beams=num_beams, + device=device, + num_beam_groups=num_beam_groups, + ) + # instantiate logits processors + logits_processor = ( + LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)]) + if logit_processor is None + else logit_processor + ) + + num_beams = beam_scorer.num_beams + num_beam_groups = beam_scorer.num_beam_groups + num_sub_beams = num_beams // num_beam_groups + batch_size = len(beam_scorer._beam_hyps) // num_beam_groups + batch_beam_size, cur_len = input_ids.shape + beam_indices = None + + if num_beams * batch_size != batch_beam_size: + raise ValueError( + f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." + ) + + beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) + # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in + # the same group don't produce same tokens everytime. 
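+        # e.g. num_beams=6, num_beam_groups=3 -> num_sub_beams=2, so columns 0, 2, 4
+        # (the first beam of every group) start at 0 while the remaining beams keep
+        # the -1e9 fill value from torch.full above.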
+ beam_scores[:, ::num_sub_beams] = 0 + beam_scores = beam_scores.view((batch_size * num_beams,)) + + while True: + + # predicted tokens in cur_len step + current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) + + # indices which will form the beams in the next time step + reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) + + # do one decoder step on all beams of all sentences in batch + model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs) + outputs = self( + model_inputs['images'], + model_inputs['text'], + image_latent=image_latent, + image_embs=image_embs, + output_labels=False, + ) + + for beam_group_idx in range(num_beam_groups): + group_start_idx = beam_group_idx * num_sub_beams + group_end_idx = min(group_start_idx + num_sub_beams, num_beams) + group_size = group_end_idx - group_start_idx + + # indices of beams of current group among all sentences in batch + batch_group_indices = [] + + for batch_idx in range(batch_size): + batch_group_indices.extend( + [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] + ) + group_input_ids = input_ids[batch_group_indices] + + # select outputs of beams of currentg group only + next_token_logits = outputs['logits'][batch_group_indices, -1, :] + vocab_size = next_token_logits.shape[-1] + + next_token_scores_processed = logits_processor( + group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx + ) + next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) + next_token_scores = next_token_scores.expand_as(next_token_scores_processed) + + # reshape for beam search + next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) + + next_token_scores, next_tokens = torch.topk( + next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True + ) + + next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") + next_tokens = next_tokens % vocab_size + + # stateless + process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + beam_outputs = beam_scorer.process( + group_input_ids, + next_token_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + beam_indices=process_beam_indices, + group_index=beam_group_idx, + ) + beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] + beam_next_tokens = beam_outputs["next_beam_tokens"] + beam_idx = beam_outputs["next_beam_indices"] + + input_ids[batch_group_indices] = group_input_ids[beam_idx] + group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) + current_tokens[batch_group_indices] = group_input_ids[:, -1] + + # (beam_idx // group_size) -> batch_idx + # (beam_idx % group_size) -> offset of idx inside the group + reordering_indices[batch_group_indices] = ( + num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) + ) + + input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) + + # increase cur_len + cur_len = cur_len + 1 + if beam_scorer.is_done or all(stopping_criteria(input_ids, None)): + break + + final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None + sequence_outputs = beam_scorer.finalize( + input_ids, + beam_scores, + next_tokens, + next_indices, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + 
max_length=stopping_criteria.max_length, + beam_indices=final_beam_indices, + ) + return sequence_outputs['sequences'] + + +def create_model( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + force_preprocess_cfg: Optional[Dict[str, Any]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, + require_pretrained: bool = False, + load_weights_only: bool = True, + **model_kwargs, +): + """Creates and configures a contrastive vision-language model. + + Args: + model_name: Name of the model architecture to create. Can be a local model name + or a Hugging Face model ID prefixed with 'hf-hub:'. + pretrained: Tag/path for pretrained model weights. Can be: + - A pretrained tag name (e.g., 'openai') + - A path to local weights + - None to initialize with random weights + precision: Model precision/AMP configuration. Options: + - 'fp32': 32-bit floating point + - 'fp16'/'bf16': Mixed precision with FP32 for certain layers + - 'pure_fp16'/'pure_bf16': Pure 16-bit precision + device: Device to load the model on ('cpu', 'cuda', or torch.device object) + jit: If True, JIT compile the model + force_quick_gelu: Force use of QuickGELU activation + force_custom_text: Force use of custom text encoder + force_patch_dropout: Override default patch dropout value + force_image_size: Override default image size for vision encoder + force_preprocess_cfg: Override default preprocessing configuration + pretrained_image: Load pretrained weights for timm vision models + pretrained_hf: Load pretrained weights for HF text models when not loading CLIP weights + cache_dir: Override default cache directory for downloaded model files + output_dict: If True and model supports it, return dictionary of features + require_pretrained: Raise error if pretrained weights cannot be loaded + load_weights_only: Only deserialize model weights and unpickling torch checkpoints (for safety) + **model_kwargs: Additional keyword arguments passed to model constructor + + Returns: + Created and configured model instance + + Raises: + RuntimeError: If model config is not found or required pretrained weights + cannot be loaded + + Examples: + # Create basic CLIP model + model = create_model('ViT-B/32') + + # Create CLIP model with mixed precision on GPU + model = create_model('ViT-B/32', precision='fp16', device='cuda') + + # Load pretrained OpenAI weights + model = create_model('ViT-B/32', pretrained='openai') + + # Load Hugging Face model + model = create_model('hf-hub:organization/model-name') + """ + + force_preprocess_cfg = force_preprocess_cfg or {} + preprocess_cfg = asdict(PreprocessCfg()) + has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) + if has_hf_hub_prefix: + model_id = model_name[len(HF_HUB_PREFIX):] + checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) + config = _get_hf_config(model_id, cache_dir=cache_dir) + preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg']) + model_cfg = config['model_cfg'] + pretrained_hf = False # override, no need to load original HF text weights + else: + model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names + checkpoint_path = None + 
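+        # model_cfg stays None on this path; it is resolved just below via
+        # get_model_config(), i.e. from the JSON configs that _rescan_model_configs()
+        # registered out of _MODEL_CONFIG_PATHS.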
model_cfg = None + + if isinstance(device, str): + device = torch.device(device) + + model_cfg = model_cfg or get_model_config(model_name) + if model_cfg is not None: + logging.info(f'Loaded {model_name} model config.') + else: + logging.error(f'Model config for {model_name} not found; available models {list_models()}.') + raise RuntimeError(f'Model config for {model_name} not found.') + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + if force_patch_dropout is not None: + # override the default patch dropout value + model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout + + if force_image_size is not None: + # override model config's image size + model_cfg["vision_cfg"]["image_size"] = force_image_size + + is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) + if pretrained_image: + if is_timm_model: + # pretrained weight loading for timm models set via vision_cfg + model_cfg['vision_cfg']['timm_model_pretrained'] = True + else: + assert False, 'pretrained image towers currently only supported for timm models' + + # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes + cast_dtype = get_cast_dtype(precision) + is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) + if is_hf_model: + # load pretrained weights for HF text model IFF no CLIP weights being loaded + model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained + custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model + + model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg) + if custom_text: + if "multimodal_cfg" in model_cfg: + model = CoCa(**model_cfg, cast_dtype=cast_dtype) + else: + model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) + else: + model = CLIP(**model_cfg, cast_dtype=cast_dtype) + + if precision in ("fp16", "bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + # manual mixed precision that matches original OpenAI behaviour + if is_timm_model: + # FIXME this is a bit janky, create timm based model in low-precision and + # then cast only LayerNormFp32 instances back to float32 so they don't break. + # Why? The convert_weights_to_lp fn only works with native models. + model.to(device=device, dtype=dtype) + # from .transformer import LayerNormFp32 + + def _convert_ln(m): + if isinstance(m, LayerNormFp32): + m.weight.data = m.weight.data.to(torch.float32) + m.bias.data = m.bias.data.to(torch.float32) + model.apply(_convert_ln) + else: + model.to(device=device) + convert_weights_to_lp(model, dtype=dtype) + elif precision in ("pure_fp16", "pure_bf16"): + dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 + model.to(device=device, dtype=dtype) + else: + model.to(device=device) + + pretrained_loaded = False + if pretrained: + checkpoint_path = '' + pretrained_cfg = get_pretrained_cfg(model_name, pretrained) + if pretrained_cfg: + checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) + preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg) + pretrained_quick_gelu = pretrained_cfg.get('quick_gelu', False) + model_quick_gelu = model_cfg.get('quick_gelu', False) + if pretrained_quick_gelu and not model_quick_gelu: + warnings.warn( + f'These pretrained weights were trained with QuickGELU activation but the model config does ' + f'not have that enabled. 
Consider using a model config with a "-quickgelu" suffix or enable with a flag.') + elif not pretrained_quick_gelu and model_quick_gelu: + warnings.warn( + f'The pretrained weights were not trained with QuickGELU but this activation is enabled in the ' + f'model config, consider using a model config without QuickGELU or disable override flags.') + elif os.path.exists(pretrained): + checkpoint_path = pretrained + + if checkpoint_path: + logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') + load_checkpoint(model, checkpoint_path, weights_only=load_weights_only) + else: + error_str = ( + f'Pretrained weights ({pretrained}) not found for model {model_name}.' + f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') + logging.warning(error_str) + raise RuntimeError(error_str) + pretrained_loaded = True + elif has_hf_hub_prefix: + logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).') + load_checkpoint(model, checkpoint_path, weights_only=load_weights_only) + pretrained_loaded = True + + if require_pretrained and not pretrained_loaded: + # callers of create_model_from_pretrained always expect pretrained weights + raise RuntimeError( + f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') + + if output_dict and hasattr(model, "output_dict"): + model.output_dict = True + + if jit: + model = torch.jit.script(model) + + # set image preprocessing configuration in model attributes for convenience + if getattr(model.visual, 'image_size', None) is not None: + # use image_size set on model creation (via config or force_image_size arg) + force_preprocess_cfg['size'] = model.visual.image_size + set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg)) + + return model + +def create_model_and_transforms( + model_name: str, + pretrained: Optional[str] = None, + precision: str = 'fp32', + device: Union[str, torch.device] = 'cpu', + jit: bool = False, + force_quick_gelu: bool = False, + force_custom_text: bool = False, + force_patch_dropout: Optional[float] = None, + force_image_size: Optional[Union[int, Tuple[int, int]]] = None, + image_mean: Optional[Tuple[float, ...]] = None, + image_std: Optional[Tuple[float, ...]] = None, + image_interpolation: Optional[str] = None, + image_resize_mode: Optional[str] = None, # only effective for inference + aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, + pretrained_image: bool = False, + pretrained_hf: bool = True, + cache_dir: Optional[str] = None, + output_dict: Optional[bool] = None, + load_weights_only: bool = True, + **model_kwargs, +): + force_preprocess_cfg = merge_preprocess_kwargs( + {}, + mean=image_mean, + std=image_std, + interpolation=image_interpolation, + resize_mode=image_resize_mode, + ) + + model = create_model( + model_name, + pretrained, + precision=precision, + device=device, + jit=jit, + force_quick_gelu=force_quick_gelu, + force_custom_text=force_custom_text, + force_patch_dropout=force_patch_dropout, + force_image_size=force_image_size, + force_preprocess_cfg=force_preprocess_cfg, + pretrained_image=pretrained_image, + pretrained_hf=pretrained_hf, + cache_dir=cache_dir, + output_dict=output_dict, + load_weights_only=load_weights_only, + **model_kwargs, + ) + + pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg) + + preprocess_train = image_transform_v2( + pp_cfg, + is_train=True, + aug_cfg=aug_cfg, + ) + preprocess_val = image_transform_v2( + pp_cfg, + is_train=False, + ) + + 
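+    # preprocess_train applies the stochastic pipeline from aug_cfg (random resized
+    # crop plus optional jitter/grayscale, or a timm pipeline when use_timm is set),
+    # while preprocess_val only resizes/crops and normalizes per the PreprocessCfg.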
return model, preprocess_train, preprocess_val + + + +open_clip_model, open_clip_imgaug, open_clip_preprocess = create_model_and_transforms( + model_name='ViT-H-14', pretrained='laion2b_s32b_b79k', device=device +) +# print("ashish 1") +# exit() + +# ================================================================== +# C S I P - M O D U L E +# ================================================================== +class CSIP(nn.Module): + def __init__(self, image_encoder, audio_encoder, + dim_img=None, dim_audio=1024, dim_emb=1024): + super(CSIP, self).__init__() + + self.image_encoder = image_encoder # CLIPVisionModel + self.audio_encoder = audio_encoder # CLAP_audio_encoder + + for param in self.image_encoder.parameters(): + param.requires_grad = False + + # self.image_proj = nn.Linear(dim_img, dim_emb) + self.audio_proj = nn.Linear(dim_audio, dim_emb) + + # Learnable temperature parameter + # self.log_temp = nn.Parameter(torch.tensor(1/0.07).log()) + self.log_temp = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def forward(self, images, audios): + + # image_features = self.image_encoder(images) # shape: [n, dim_img] + image_features = images # shape: [n, dim_img] + audio_features = self.audio_encoder(audios)[0] # shape: [n, dim_audio] + + # Step 2: Project and normalize + image_embeds = F.normalize(image_features, dim=1) # [n, dim_emb] + audio_embeds = F.normalize(self.audio_proj(audio_features), dim=1) # [n, dim_emb] + + # Step 3: Cosine similarity with temperature + logits = torch.matmul(image_embeds, audio_embeds.T) * self.log_temp.exp() # [n, n] + probs = logits.softmax(dim=1) + + # Step 4: Symmetric cross-entropy loss + labels = torch.arange(len(images), device=images.device) + loss_i = F.cross_entropy(logits, labels) + loss_a = F.cross_entropy(logits.T, labels) + loss = (loss_i + loss_a) / 2 + + # Step 5: Similarity metric (average cosine similarity on matched pairs) + similarity_scores = (image_embeds * audio_embeds).sum(dim=1) # Cosine similarity of matching pairs + avg_similarity = similarity_scores.mean() + + return loss, loss_i, loss_a, logits, probs, avg_similarity + + + + + + + + + + + + + + + + +if __name__ == "__main__": + # ================================================================== + # I M A G E - A U D I O - D A T A S E T + # ================================================================== + class VaaniImageAudioDataset(torch.utils.data.Dataset): + def __init__(self, df, image_features_savedir, audio_tensors_savedir): + self.image_paths = df.image_path.tolist() + self.audio_paths = df.audio_path.tolist() + self.image_features_savedir = image_features_savedir + self.audio_tensors_savedir = audio_tensors_savedir + + def __len__(self): + return len(self.audio_paths) + + def __getitem__(self, idx): + return { + 'image_path': self.image_paths[idx], + 'image_feature': torch.load(os.path.join( + self.image_features_savedir, + f"{os.path.basename(self.image_paths[idx])}.pt"))['image_features'], + 'audio_path': self.audio_paths[idx], + 'audio_tensor': torch.load(os.path.join( + audio_tensors_savedir, + f"{os.path.basename(self.audio_paths[idx])}.pt"))['audio_tensor'] + } + + + train_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TRAIN3.csv") + test_df = pd.read_csv("/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TEST2.csv") + image_features_savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Image_features/' + audio_tensors_savedir = 
'/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Audio_tensors/' + train_dataset = VaaniImageAudioDataset(train_df, image_features_savedir, audio_tensors_savedir) + test_dataset = VaaniImageAudioDataset(test_df, image_features_savedir, audio_tensors_savedir) + + print('Train Dataset:', len(train_dataset)) + print('Test Dataset:', len(test_dataset)) + + + BATCH_SIZE = int(128) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=BATCH_SIZE, + shuffle=True, + num_workers=48, + pin_memory=True, + drop_last=False, + persistent_workers=True + ) + + test_dataloader = torch.utils.data.DataLoader( + test_dataset, + batch_size=BATCH_SIZE, + shuffle=False, + num_workers=48, + pin_memory=True, + drop_last=False, + persistent_workers=True + ) + + batch = next(iter(train_dataloader)) + image_features_batch = batch['image_feature'].to(device=device) + audio_tensor_batch = batch['audio_tensor'].to(device=device) + image_paths_batch = batch['image_path'] + audio_paths_batch = batch['audio_path'] + print("Image batch shape:", image_features_batch.shape) # [BATCH_SIZE, 3, 224, 224] + print("Audio batch shape:", audio_tensor_batch.shape) # [BATCH_SIZE, 1, 44100] + + + + csip_model = CSIP(open_clip_model.visual, peft_clap_audio_encoder).to(device) + # csip_model = nn.DataParallel(CSIP2(open_clip_model.visual, peft_clap_audio_encoder), device_ids=[0, 1]).to(device) + + from torchinfo import summary + import subprocess + + for param in csip_model.audio_encoder.model.projection.parameters(): + param.requires_grad = True + + summary(model=csip_model, + input_data=((image_features_batch.to(device)), (audio_tensor_batch.to(device))), + # input_size = (1, 3, config.IMAGE_SIZE, config.IMAGE_SIZE), + dtypes=[torch.long], + col_names = ["trainable", "params_percent", "input_size", "output_size", "num_params"], + col_width=20, + row_settings=["var_names"], + depth = 4, + # verbose=2, + # device=device + ) + + # loss, logits, probs = csip_model(batch['image_tensor'].to(device), batch['audio_tensor'].to(device)) + # loss, logits, probs, logits.shape, probs.shape +
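+
+    # ------------------------------------------------------------------
+    # Minimal CSIP training-loop sketch (illustrative only, not the
+    # author's original training code). It reuses the objects built
+    # above (csip_model, train_dataloader, test_dataloader, device);
+    # the optimizer choice, learning rate, epoch count and the in-batch
+    # top-1 retrieval check are placeholder assumptions.
+    # ------------------------------------------------------------------
+    EPOCHS = 1  # placeholder; raise for a real run
+    trainable_params = [p for p in csip_model.parameters() if p.requires_grad]
+    optimizer = torch.optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-2)
+
+    for epoch in range(EPOCHS):
+        csip_model.train()
+        running_loss = 0.0
+        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{EPOCHS}"):
+            images = batch['image_feature'].to(device, non_blocking=True)
+            audios = batch['audio_tensor'].to(device, non_blocking=True)
+
+            loss, loss_i, loss_a, logits, probs, avg_sim = csip_model(images, audios)
+
+            optimizer.zero_grad(set_to_none=True)
+            loss.backward()
+            optimizer.step()
+            running_loss += loss.item()
+
+        print(f"[epoch {epoch + 1}] train loss: {running_loss / len(train_dataloader):.4f}")
+
+        # In-batch audio retrieval check: with aligned pairs, the diagonal of the
+        # image-audio logit matrix should dominate each row (top-1 accuracy).
+        csip_model.eval()
+        correct, total = 0, 0
+        with torch.no_grad():
+            for batch in test_dataloader:
+                images = batch['image_feature'].to(device)
+                audios = batch['audio_tensor'].to(device)
+                _, _, _, logits, _, _ = csip_model(images, audios)
+                preds = logits.argmax(dim=1)
+                labels = torch.arange(logits.shape[0], device=logits.device)
+                correct += (preds == labels).sum().item()
+                total += logits.shape[0]
+        print(f"[epoch {epoch + 1}] test in-batch top-1 retrieval: {correct / max(total, 1):.4f}")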