diff --git "a/Vaani/Img_Audio_Alignment/_2.1.2_Train_OpenCLIP.ipynb" "b/Vaani/Img_Audio_Alignment/_2.1.2_Train_OpenCLIP.ipynb" new file mode 100644--- /dev/null +++ "b/Vaani/Img_Audio_Alignment/_2.1.2_Train_OpenCLIP.ipynb" @@ -0,0 +1,7166 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "db6bc613", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cpu\n", + "Author: Ashish\n", + "\n", + "Last updated: 2025-05-27T15:47:36.188763+05:30\n", + "\n", + "Python implementation: CPython\n", + "Python version : 3.11.11\n", + "IPython version : 9.1.0\n", + "\n", + "conda environment: clap\n", + "\n", + "Compiler : GCC 11.2.0\n", + "OS : Linux\n", + "Release : 4.18.0-513.5.1.el8_9.x86_64\n", + "Machine : x86_64\n", + "Processor : x86_64\n", + "CPU cores : 48\n", + "Architecture: 64bit\n", + "\n", + "Hostname: login01\n", + "\n", + "huggingface_hub: 0.31.2\n", + "torch : 2.1.2\n", + "torchlibrosa : 0.1.0\n", + "seaborn : 0.13.2\n", + "sys : 3.11.11 (main, Dec 11 2024, 16:28:39) [GCC 11.2.0]\n", + "torchvision : 0.16.2\n", + "watermark : 2.5.0\n", + "re : 2.2.1\n", + "pandas : 2.2.3\n", + "tqdm : 4.67.1\n", + "PIL : 11.1.0\n", + "transformers : 4.51.3\n", + "peft : 0.15.2\n", + "torchaudio : 2.1.2\n", + "matplotlib : 3.10.1\n", + "numpy : 1.26.0\n", + "csv : 1.0\n", + "\n", + "GPU Info: NVIDIA drivers do not appear to be installed on this machine.\n", + "\n", + "trainable params: 1,289,494 || all params: 34,355,927 || trainable%: 3.7533\n", + "Train Dataset: 62265\n", + "Test Dataset: 11490\n", + "Image batch shape: torch.Size([128, 1024])\n", + "Audio batch shape: torch.Size([128, 308700])\n" + ] + } + ], + "source": [ + "# ==================================================================\n", + "# L A T E N T D I F F U S I O N M O D E L\n", + "# ==================================================================\n", + "# Author : Ashish Kumar Uchadiya\n", + "# Created : May 11, 2025\n", + "# Description: This script implements the training of a VQ-VAE model for\n", + "# image reconstruction, integrated with Latent Diffusion Models (LDMs) and\n", + "# audio conditioning. The VQ-VAE maps images to a discrete latent space, \n", + "# which is then modeled by the LDM for learning a diffusion process over the \n", + "# compressed representation. Audio features are used as conditioning inputs \n", + "# to guide the generation process. The training minimizes a combination of \n", + "# LPIPS (Learned Perceptual Image Patch Similarity) loss for perceptual \n", + "# fidelity and PatchGAN loss to enforce local realism. 
This setup enables \n", + "# efficient and semantically-aware generation of high-quality images driven \n", + "# by audio cues.\n", + "# ==================================================================\n", + "# I M P O R T S\n", + "# ==================================================================\n", + "from __future__ import annotations\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "import os\n", + "import io\n", + "import sys\n", + "import math\n", + "import random\n", + "import collections\n", + "import collections.abc\n", + "import re\n", + "from itertools import repeat\n", + "from pathlib import Path\n", + "from typing import Optional, Tuple, Union, List, Dict\n", + "\n", + "import csv\n", + "import copy\n", + "import numpy as np\n", + "import pandas as pd\n", + "from PIL import Image\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt\n", + "from tqdm import trange, tqdm\n", + "\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torch import nn\n", + "from torch.nn.init import _calculate_fan_in_and_fan_out\n", + "import torch.utils.checkpoint as checkpoint\n", + "\n", + "import torchvision\n", + "from torchvision.transforms import v2\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "# from tensorboardX import SummaryWriter\n", + "\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Using device: {device}\")\n", + "\n", + "import torchaudio\n", + "import torchaudio.transforms as T\n", + "from torchlibrosa.stft import Spectrogram, LogmelFilterBank\n", + "from torchlibrosa.augmentation import SpecAugmentation\n", + "\n", + "from transformers import AutoModel, AutoTokenizer, logging\n", + "from huggingface_hub.file_download import hf_hub_download\n", + "from huggingface_hub.file_download import hf_hub_download\n", + "from peft import get_peft_config, get_peft_model\n", + "from transformers import CLIPVisionModel, AutoProcessor\n", + "\n", + "from watermark import watermark\n", + "print(watermark(\n", + " author='Ashish',\n", + " # email='ashish@example.com',\n", + " current_date=True,\n", + " datename=True,\n", + " current_time=True,\n", + " iso8601=True,\n", + " timezone=True,\n", + " updated=True,\n", + " custom_time=None,\n", + " python=True,\n", + " # packages=\"torch,torchvision,numpy\",\n", + " conda=True,\n", + " hostname=True,\n", + " machine=True,\n", + " watermark=False,\n", + " iversions=True,\n", + " gpu=True,\n", + " globals_=globals()\n", + "))\n", + "\n", + "\n", + "# ==================================================================\n", + "# H T S - A T\n", + "# ==================================================================\n", + "class HTSATConfig:\n", + " # Ke Chen\n", + " # knutchen@ucsd.edu\n", + " # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION\n", + " # The configuration for training the model\n", + "\n", + " exp_name = \"exp_htsat_pretrain\" # the saved ckpt prefix name of the model \n", + " workspace = \"/home/kechen/Research/HTSAT\" # the folder of your code\n", + " dataset_path = \"/home/Research/audioset\" # the dataset path\n", + " desed_folder = \"/home/Research/DESED\" # the desed file\n", + "\n", + " dataset_type = \"audioset\" # \"audioset\" \"esc-50\" \"scv2\"\n", + " index_type = \"full_train\" # only works for audioset\n", + " balanced_data = True # only works for audioset\n", + "\n", + " loss_type = \"clip_bce\" # 
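+    "    # --- Editor's sketch (hedged, not part of the original config): how the two\n",
+    "    # loss_type options are consumed downstream. forward_features() returns raw\n",
+    "    # logits in 'clipwise_output' when loss_type == 'clip_ce' and sigmoid\n",
+    "    # probabilities otherwise, so a minimal criterion would look like:\n",
+    "    #   if loss_type == 'clip_ce':        # single-label (ESC-50)\n",
+    "    #       loss = F.cross_entropy(out['clipwise_output'], target_ids)\n",
+    "    #   else:                             # multi-label 'clip_bce' (AudioSet, SCV2)\n",
+    "    #       loss = F.binary_cross_entropy(out['clipwise_output'], target_multi_hot)\n",
+    "    # 'out', 'target_ids' and 'target_multi_hot' are illustrative names only.\n",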
\n", + " # AudioSet & SCV2: \"clip_bce\" | ESC-50: \"clip_ce\" \n", + "\n", + " # trained from a checkpoint, or evaluate a single model \n", + " resume_checkpoint = None \n", + " # \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt\"\n", + " \n", + " esc_fold = 0 # just for esc dataset, select the fold you need for evaluation and (+1) validation\n", + "\n", + "\n", + " debug = False\n", + "\n", + " random_seed = 970131 # 19970318 970131 12412 127777 1009 34047\n", + " batch_size = 32 * 4 # batch size per GPU x GPU number , default is 32 x 4 = 128\n", + " learning_rate = 1e-3 # 1e-4 also workable \n", + " max_epoch = 100\n", + " num_workers = 3\n", + "\n", + " lr_scheduler_epoch = [10,20,30]\n", + " lr_rate = [0.02, 0.05, 0.1]\n", + "\n", + " # these data preparation optimizations do not bring many improvements, so deprecated\n", + " enable_token_label = False # token label\n", + " class_map_path = \"class_hier_map.npy\"\n", + " class_filter = None \n", + " retrieval_index = [15382, 9202, 130, 17618, 17157, 17516, 16356, 6165, 13992, 9238, 5550, 5733, 1914, 1600, 3450, 13735, 11108, 3762, \n", + " 9840, 11318, 8131, 4429, 16748, 4992, 16783, 12691, 4945, 8779, 2805, 9418, 2797, 14357, 5603, 212, 3852, 12666, 1338, 10269, 2388, 8260, 4293, 14454, 7677, 11253, 5060, 14938, 8840, 4542, 2627, 16336, 8992, 15496, 11140, 446, 6126, 10691, 8624, 10127, 9068, 16710, 10155, 14358, 7567, 5695, 2354, 8057, 17635, 133, 16183, 14535, 7248, 4560, 14429, 2463, 10773, 113, 2462, 9223, 4929, 14274, 4716, 17307, 4617, 2132, 11083, 1039, 1403, 9621, 13936, 2229, 2875, 17840, 9359, 13311, 9790, 13288, 4750, 17052, 8260, 14900]\n", + " token_label_range = [0.2,0.6]\n", + " enable_time_shift = False # shift time\n", + " enable_label_enhance = False # enhance hierarchical label\n", + " enable_repeat_mode = False # repeat the spectrogram / reshape the spectrogram\n", + "\n", + "\n", + "\n", + " # for model's design\n", + " enable_tscam = True # enbale the token-semantic layer\n", + "\n", + " # for signal processing\n", + " sample_rate = 32000 # 16000 for scv2, 32000 for audioset and esc-50\n", + " clip_samples = sample_rate * 10 # audio_set 10-sec clip\n", + " window_size = 1024\n", + " hop_size = 320 # 160 for scv2, 320 for audioset and esc-50\n", + " mel_bins = 64\n", + " fmin = 50\n", + " fmax = 14000\n", + " shift_max = int(clip_samples * 0.5)\n", + "\n", + " # for data collection\n", + " classes_num = 527 # esc: 50 | audioset: 527 | scv2: 35\n", + " patch_size = (25, 4) # deprecated\n", + " crop_size = None # int(clip_samples * 0.5) deprecated\n", + "\n", + " # for htsat hyperparamater\n", + " htsat_window_size = 8\n", + " htsat_spec_size = 256\n", + " htsat_patch_size = 4 \n", + " htsat_stride = (4, 4)\n", + " htsat_num_head = [4,8,16,32]\n", + " htsat_dim = 96 \n", + " htsat_depth = [2,2,6,2]\n", + "\n", + " swin_pretrain_path = None\n", + " # \"/home/Research/model_backup/pretrain/swin_tiny_c24_patch4_window8_256.pth\"\n", + "\n", + " # Some Deprecated Optimization in the model design, check the model code for details\n", + " htsat_attn_heatmap = False\n", + " htsat_hier_output = False \n", + " htsat_use_max = False\n", + "\n", + "\n", + " # for ensemble test \n", + "\n", + " ensemble_checkpoints = []\n", + " ensemble_strides = []\n", + "\n", + "\n", + " # weight average folder\n", + " wa_folder = \"/home/version_0/checkpoints/\"\n", + " # weight average output filename\n", + " wa_model_path = \"HTSAT_AudioSet_Saved_x.ckpt\"\n", + "\n", + " esm_model_pathes = [\n", + " 
\"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_1.ckpt\",\n", + " \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_2.ckpt\",\n", + " \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_3.ckpt\",\n", + " \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_4.ckpt\",\n", + " \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_5.ckpt\",\n", + " \"/home/Research/model_backup/AudioSet/HTSAT_AudioSet_Saved_6.ckpt\"\n", + " ]\n", + "\n", + " # for framewise localization\n", + " heatmap_dir = \"/home/Research/heatmap_output\"\n", + " test_file = \"htsat-test-ensemble\"\n", + " fl_local = False # indicate if we need to use this dataset for the framewise detection\n", + " fl_dataset = \"/home/Research/desed/desedim_embval.npy\" \n", + " fl_class_num = [\n", + " \"Speech\", \"Frying\", \"Dishes\", \"Running_water\",\n", + " \"Blender\", \"Electric_shaver_toothbrush\", \"Alarm_bell_ringing\",\n", + " \"Cat\", \"Dog\", \"Vacuum_cleaner\"\n", + " ]\n", + "\n", + " # map 527 classes into 10 classes\n", + " fl_audioset_mapping = [\n", + " [0,1,2,3,4,5,6,7],\n", + " [366, 367, 368],\n", + " [364],\n", + " [288, 289, 290, 291, 292, 293, 294, 295, 296, 297],\n", + " [369],\n", + " [382],\n", + " [310, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402],\n", + " [81, 82, 83, 84, 85],\n", + " [74, 75, 76, 77, 78, 79],\n", + " [377]\n", + " ]\n", + "\n", + "\n", + "\n", + "def _ntuple(n):\n", + " def parse(x):\n", + " if isinstance(x, collections.abc.Iterable):\n", + " return x\n", + " return tuple(repeat(x, n))\n", + " return parse\n", + "\n", + "to_1tuple = _ntuple(1)\n", + "to_2tuple = _ntuple(2)\n", + "to_3tuple = _ntuple(3)\n", + "to_4tuple = _ntuple(4)\n", + "to_ntuple = _ntuple\n", + "\n", + "def do_mixup(x, mixup_lambda):\n", + " \"\"\"Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes \n", + " (1, 3, 5, ...).\n", + " Args:\n", + " x: (batch_size * 2, ...)\n", + " mixup_lambda: (batch_size * 2,)\n", + " Returns:\n", + " out: (batch_size, ...)\n", + " \"\"\"\n", + " out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \\\n", + " x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1)\n", + " return out\n", + "\n", + "def interpolate(x, ratio):\n", + " \"\"\"Interpolate data in time domain. This is used to compensate the \n", + " resolution reduction in downsampling of a CNN.\n", + " \n", + " Args:\n", + " x: (batch_size, time_steps, classes_num)\n", + " ratio: int, ratio to interpolate\n", + " Returns:\n", + " upsampled: (batch_size, time_steps * ratio, classes_num)\n", + " \"\"\"\n", + " (batch_size, time_steps, classes_num) = x.shape\n", + " upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)\n", + " upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)\n", + " return upsampled\n", + "\n", + "\n", + "def drop_path(x, drop_prob: float = 0., training: bool = False):\n", + " \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n", + " This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,\n", + " the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...\n", + " See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for\n", + " changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use\n", + " 'survival rate' as the argument.\n", + " \"\"\"\n", + " if drop_prob == 0. 
or not training:\n", + " return x\n", + " keep_prob = 1 - drop_prob\n", + " shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets\n", + " random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)\n", + " random_tensor.floor_() # binarize\n", + " output = x.div(keep_prob) * random_tensor\n", + " return output\n", + "\n", + "\n", + "class DropPath(nn.Module):\n", + " \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n", + " \"\"\"\n", + " def __init__(self, drop_prob=None):\n", + " super(DropPath, self).__init__()\n", + " self.drop_prob = drop_prob\n", + "\n", + " def forward(self, x):\n", + " return drop_path(x, self.drop_prob, self.training)\n", + "\n", + "class PatchEmbed(nn.Module):\n", + " \"\"\" 2D Image to Patch Embedding\n", + " \"\"\"\n", + " def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, patch_stride = 16):\n", + " super().__init__()\n", + " img_size = to_2tuple(img_size)\n", + " patch_size = to_2tuple(patch_size)\n", + " patch_stride = to_2tuple(patch_stride)\n", + " self.img_size = img_size\n", + " self.patch_size = patch_size\n", + " self.patch_stride = patch_stride\n", + " self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1])\n", + " self.num_patches = self.grid_size[0] * self.grid_size[1]\n", + " self.flatten = flatten\n", + " self.in_chans = in_chans\n", + " self.embed_dim = embed_dim\n", + " \n", + " padding = ((patch_size[0] - patch_stride[0]) // 2, (patch_size[1] - patch_stride[1]) // 2)\n", + "\n", + " self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=padding)\n", + " self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()\n", + "\n", + " def forward(self, x):\n", + " B, C, H, W = x.shape\n", + " assert H == self.img_size[0] and W == self.img_size[1], \\\n", + " f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n", + " x = self.proj(x)\n", + " if self.flatten:\n", + " x = x.flatten(2).transpose(1, 2) # BCHW -> BNC\n", + " x = self.norm(x)\n", + " return x\n", + "\n", + "class Mlp(nn.Module):\n", + " \"\"\" MLP as used in Vision Transformer, MLP-Mixer and related networks\n", + " \"\"\"\n", + " def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n", + " super().__init__()\n", + " out_features = out_features or in_features\n", + " hidden_features = hidden_features or in_features\n", + " self.fc1 = nn.Linear(in_features, hidden_features)\n", + " self.act = act_layer()\n", + " self.fc2 = nn.Linear(hidden_features, out_features)\n", + " self.drop = nn.Dropout(drop)\n", + "\n", + " def forward(self, x):\n", + " x = self.fc1(x)\n", + " x = self.act(x)\n", + " x = self.drop(x)\n", + " x = self.fc2(x)\n", + " x = self.drop(x)\n", + " return x\n", + "\n", + "def _no_gradim_audiorunc_normal_(tensor, mean, std, a, b):\n", + " # Cut & paste from PyTorch official master until it's in a few official releases - RW\n", + " # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf\n", + " def norm_cdf(x):\n", + " # Computes standard normal cumulative distribution function\n", + " return (1. + math.erf(x / math.sqrt(2.))) / 2.\n", + "\n", + " if (mean < a - 2 * std) or (mean > b + 2 * std):\n", + " warnings.warn(\"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
\"\n", + " \"The distribution of values may be incorrect.\",\n", + " stacklevel=2)\n", + "\n", + " with torch.no_grad():\n", + " # Values are generated by using a truncated uniform distribution and\n", + " # then using the inverse CDF for the normal distribution.\n", + " # Get upper and lower cdf values\n", + " l = norm_cdf((a - mean) / std)\n", + " u = norm_cdf((b - mean) / std)\n", + "\n", + " # Uniformly fill tensor with values from [l, u], then translate to\n", + " # [2l-1, 2u-1].\n", + " tensor.uniform_(2 * l - 1, 2 * u - 1)\n", + "\n", + " # Use inverse cdf transform for normal distribution to get truncated\n", + " # standard normal\n", + " tensor.erfinv_()\n", + "\n", + " # Transform to proper mean, std\n", + " tensor.mul_(std * math.sqrt(2.))\n", + " tensor.add_(mean)\n", + "\n", + " # Clamp to ensure it's in the proper range\n", + " tensor.clamp_(min=a, max=b)\n", + " return tensor\n", + "\n", + "\n", + "def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):\n", + " # type: (Tensor, float, float, float, float) -> Tensor\n", + " r\"\"\"Fills the input Tensor with values drawn from a truncated\n", + " normal distribution. The values are effectively drawn from the\n", + " normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`\n", + " with values outside :math:`[a, b]` redrawn until they are within\n", + " the bounds. The method used for generating the random values works\n", + " best when :math:`a \\leq \\text{mean} \\leq b`.\n", + " Args:\n", + " tensor: an n-dimensional `torch.Tensor`\n", + " mean: the mean of the normal distribution\n", + " std: the standard deviation of the normal distribution\n", + " a: the minimum cutoff value\n", + " b: the maximum cutoff value\n", + " Examples:\n", + " >>> w = torch.empty(3, 5)\n", + " >>> nn.init.trunc_normal_(w)\n", + " \"\"\"\n", + " return _no_gradim_audiorunc_normal_(tensor, mean, std, a, b)\n", + "\n", + "\n", + "def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):\n", + " fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)\n", + " if mode == 'fan_in':\n", + " denom = fan_in\n", + " elif mode == 'fan_out':\n", + " denom = fan_out\n", + " elif mode == 'fan_avg':\n", + " denom = (fan_in + fan_out) / 2\n", + "\n", + " variance = scale / denom\n", + "\n", + " if distribution == \"truncated_normal\":\n", + " # constant is stddev of standard normal truncated to (-2, 2)\n", + " trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)\n", + " elif distribution == \"normal\":\n", + " tensor.normal_(std=math.sqrt(variance))\n", + " elif distribution == \"uniform\":\n", + " bound = math.sqrt(3 * variance)\n", + " tensor.uniform_(-bound, bound)\n", + " else:\n", + " raise ValueError(f\"invalid distribution {distribution}\")\n", + "\n", + "\n", + "def lecun_normal_(tensor):\n", + " variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')\n", + "\n", + "\n", + "# below codes are based and referred from https://github.com/microsoft/Swin-Transformer\n", + "# Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf\n", + "\n", + "def window_partition(x, window_size):\n", + " \"\"\"\n", + " Args:\n", + " x: (B, H, W, C)\n", + " window_size (int): window size\n", + " Returns:\n", + " windows: (num_windows*B, window_size, window_size, C)\n", + " \"\"\"\n", + " B, H, W, C = x.shape\n", + " x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)\n", + " windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, 
C)\n", + " return windows\n", + "\n", + "\n", + "def window_reverse(windows, window_size, H, W):\n", + " \"\"\"\n", + " Args:\n", + " windows: (num_windows*B, window_size, window_size, C)\n", + " window_size (int): Window size\n", + " H (int): Height of image\n", + " W (int): Width of image\n", + " Returns:\n", + " x: (B, H, W, C)\n", + " \"\"\"\n", + " B = int(windows.shape[0] / (H * W / window_size / window_size))\n", + " x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)\n", + " x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)\n", + " return x\n", + "\n", + "\n", + "class WindowAttention(nn.Module):\n", + " r\"\"\" Window based multi-head self attention (W-MSA) module with relative position bias.\n", + " It supports both of shifted and non-shifted window.\n", + " Args:\n", + " dim (int): Number of input channels.\n", + " window_size (tuple[int]): The height and width of the window.\n", + " num_heads (int): Number of attention heads.\n", + " qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n", + " qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set\n", + " attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0\n", + " proj_drop (float, optional): Dropout ratio of output. Default: 0.0\n", + " \"\"\"\n", + "\n", + " def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):\n", + "\n", + " super().__init__()\n", + " self.dim = dim\n", + " self.window_size = window_size # Wh, Ww\n", + " self.num_heads = num_heads\n", + " head_dim = dim // num_heads\n", + " self.scale = qk_scale or head_dim ** -0.5\n", + "\n", + " # define a parameter table of relative position bias\n", + " self.relative_position_bias_table = nn.Parameter(\n", + " torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH\n", + "\n", + " # get pair-wise relative position index for each token inside the window\n", + " coords_h = torch.arange(self.window_size[0])\n", + " coords_w = torch.arange(self.window_size[1])\n", + " coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww\n", + " coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww\n", + " relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww\n", + " relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2\n", + " relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0\n", + " relative_coords[:, :, 1] += self.window_size[1] - 1\n", + " relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1\n", + " relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww\n", + " self.register_buffer(\"relative_position_index\", relative_position_index)\n", + "\n", + " self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n", + " self.attn_drop = nn.Dropout(attn_drop)\n", + " self.proj = nn.Linear(dim, dim)\n", + " self.proj_drop = nn.Dropout(proj_drop)\n", + "\n", + " trunc_normal_(self.relative_position_bias_table, std=.02)\n", + " self.softmax = nn.Softmax(dim=-1)\n", + "\n", + " def forward(self, x, mask=None):\n", + " \"\"\"\n", + " Args:\n", + " x: input features with shape of (num_windows*B, N, C)\n", + " mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None\n", + " \"\"\"\n", + " B_, N, C = x.shape\n", + " qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n", + " q, k, v = qkv[0], 
qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)\n", + "\n", + " q = q * self.scale\n", + " attn = (q @ k.transpose(-2, -1))\n", + "\n", + " relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(\n", + " self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH\n", + " relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww\n", + " attn = attn + relative_position_bias.unsqueeze(0)\n", + "\n", + " if mask is not None:\n", + " nW = mask.shape[0]\n", + " attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)\n", + " attn = attn.view(-1, self.num_heads, N, N)\n", + " attn = self.softmax(attn)\n", + " else:\n", + " attn = self.softmax(attn)\n", + "\n", + " attn = self.attn_drop(attn)\n", + "\n", + " x = (attn @ v).transpose(1, 2).reshape(B_, N, C)\n", + " x = self.proj(x)\n", + " x = self.proj_drop(x)\n", + " return x, attn\n", + "\n", + " def extra_repr(self):\n", + " return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'\n", + "\n", + "\n", + "# We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model\n", + "class SwinTransformerBlock(nn.Module):\n", + " r\"\"\" Swin Transformer Block.\n", + " Args:\n", + " dim (int): Number of input channels.\n", + " input_resolution (tuple[int]): Input resulotion.\n", + " num_heads (int): Number of attention heads.\n", + " window_size (int): Window size.\n", + " shift_size (int): Shift size for SW-MSA.\n", + " mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n", + " qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n", + " qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n", + " drop (float, optional): Dropout rate. Default: 0.0\n", + " attn_drop (float, optional): Attention dropout rate. Default: 0.0\n", + " drop_path (float, optional): Stochastic depth rate. Default: 0.0\n", + " act_layer (nn.Module, optional): Activation layer. Default: nn.GELU\n", + " norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n", + " \"\"\"\n", + "\n", + " def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,\n", + " mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,\n", + " act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_before_mlp='ln'):\n", + " super().__init__()\n", + " self.dim = dim\n", + " self.input_resolution = input_resolution\n", + " self.num_heads = num_heads\n", + " self.window_size = window_size\n", + " self.shift_size = shift_size\n", + " self.mlp_ratio = mlp_ratio\n", + " self.norm_before_mlp = norm_before_mlp\n", + " if min(self.input_resolution) <= self.window_size:\n", + " # if window size is larger than input resolution, we don't partition windows\n", + " self.shift_size = 0\n", + " self.window_size = min(self.input_resolution)\n", + " assert 0 <= self.shift_size < self.window_size, \"shift_size must in 0-window_size\"\n", + "\n", + " self.norm1 = norm_layer(dim)\n", + " self.attn = WindowAttention(\n", + " dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,\n", + " qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n", + "\n", + " self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n", + " if self.norm_before_mlp == 'ln':\n", + " self.norm2 = nn.LayerNorm(dim)\n", + " elif self.norm_before_mlp == 'bn':\n", + " self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(1, 2)\n", + " else:\n", + " raise NotImplementedError\n", + " mlp_hidden_dim = int(dim * mlp_ratio)\n", + " self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n", + "\n", + " if self.shift_size > 0:\n", + " # calculate attention mask for SW-MSA\n", + " H, W = self.input_resolution\n", + " img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1\n", + " h_slices = (slice(0, -self.window_size),\n", + " slice(-self.window_size, -self.shift_size),\n", + " slice(-self.shift_size, None))\n", + " w_slices = (slice(0, -self.window_size),\n", + " slice(-self.window_size, -self.shift_size),\n", + " slice(-self.shift_size, None))\n", + " cnt = 0\n", + " for h in h_slices:\n", + " for w in w_slices:\n", + " img_mask[:, h, w, :] = cnt\n", + " cnt += 1\n", + "\n", + " mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1\n", + " mask_windows = mask_windows.view(-1, self.window_size * self.window_size)\n", + " attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)\n", + " attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))\n", + " else:\n", + " attn_mask = None\n", + "\n", + " self.register_buffer(\"attn_mask\", attn_mask)\n", + "\n", + " def forward(self, x):\n", + " # pdb.set_trace()\n", + " H, W = self.input_resolution\n", + " # print(\"H: \", H)\n", + " # print(\"W: \", W)\n", + " # pdb.set_trace()\n", + " B, L, C = x.shape\n", + " # assert L == H * W, \"input feature has wrong size\"\n", + "\n", + " shortcut = x\n", + " x = self.norm1(x)\n", + " x = x.view(B, H, W, C)\n", + "\n", + " # cyclic shift\n", + " if self.shift_size > 0:\n", + " shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))\n", + " else:\n", + " shifted_x = x\n", + "\n", + " # partition windows\n", + " x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C\n", + " x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C\n", + "\n", + " # W-MSA/SW-MSA\n", + " attn_windows, attn = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C\n", + "\n", + " # merge windows\n", + " attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)\n", + " shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C\n", + "\n", + " # reverse cyclic shift\n", + " if self.shift_size > 0:\n", + " x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))\n", + " else:\n", + " x = shifted_x\n", + " x = x.view(B, H * W, C)\n", + "\n", + " # FFN\n", + " x = shortcut + self.drop_path(x)\n", + " x = x + self.drop_path(self.mlp(self.norm2(x)))\n", + "\n", + " return x, attn\n", + "\n", + " def extra_repr(self):\n", + " return f\"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, \" \\\n", + " f\"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}\"\n", + "\n", + "\n", + "\n", + "class PatchMerging(nn.Module):\n", + " r\"\"\" Patch Merging Layer.\n", + " Args:\n", + " input_resolution (tuple[int]): Resolution of input feature.\n", + " dim (int): Number of input channels.\n", + " norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm\n", + " \"\"\"\n", + "\n", + " def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):\n", + " super().__init__()\n", + " self.input_resolution = input_resolution\n", + " self.dim = dim\n", + " self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)\n", + " self.norm = norm_layer(4 * dim)\n", + "\n", + " def forward(self, x):\n", + " \"\"\"\n", + " x: B, H*W, C\n", + " \"\"\"\n", + " H, W = self.input_resolution\n", + " B, L, C = x.shape\n", + " assert L == H * W, \"input feature has wrong size\"\n", + " assert H % 2 == 0 and W % 2 == 0, f\"x size ({H}*{W}) are not even.\"\n", + "\n", + " x = x.view(B, H, W, C)\n", + "\n", + " x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C\n", + " x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C\n", + " x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C\n", + " x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C\n", + " x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C\n", + " x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C\n", + "\n", + " x = self.norm(x)\n", + " x = self.reduction(x)\n", + "\n", + " return x\n", + "\n", + " def extra_repr(self):\n", + " return f\"input_resolution={self.input_resolution}, dim={self.dim}\"\n", + "\n", + "\n", + "class BasicLayer(nn.Module):\n", + " \"\"\" A basic Swin Transformer layer for one stage.\n", + " Args:\n", + " dim (int): Number of input channels.\n", + " input_resolution (tuple[int]): Input resolution.\n", + " depth (int): Number of blocks.\n", + " num_heads (int): Number of attention heads.\n", + " window_size (int): Local window size.\n", + " mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.\n", + " qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True\n", + " qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.\n", + " drop (float, optional): Dropout rate. Default: 0.0\n", + " attn_drop (float, optional): Attention dropout rate. Default: 0.0\n", + " drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0\n", + " norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm\n", + " downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None\n", + " use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False.\n", + " \"\"\"\n", + "\n", + " def __init__(self, dim, input_resolution, depth, num_heads, window_size,\n", + " mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,\n", + " drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,\n", + " norm_before_mlp='ln'):\n", + "\n", + " super().__init__()\n", + " self.dim = dim\n", + " self.input_resolution = input_resolution\n", + " self.depth = depth\n", + " self.use_checkpoint = use_checkpoint\n", + "\n", + " # build blocks\n", + " self.blocks = nn.ModuleList([\n", + " SwinTransformerBlock(dim=dim, input_resolution=input_resolution,\n", + " num_heads=num_heads, window_size=window_size,\n", + " shift_size=0 if (i % 2 == 0) else window_size // 2,\n", + " mlp_ratio=mlp_ratio,\n", + " qkv_bias=qkv_bias, qk_scale=qk_scale,\n", + " drop=drop, attn_drop=attn_drop,\n", + " drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\n", + " norm_layer=norm_layer, norm_before_mlp=norm_before_mlp)\n", + " for i in range(depth)])\n", + "\n", + " # patch merging layer\n", + " if downsample is not None:\n", + " self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)\n", + " else:\n", + " self.downsample = None\n", + "\n", + " def forward(self, x):\n", + " attns = []\n", + " for blk in self.blocks:\n", + " if self.use_checkpoint:\n", + " x = checkpoint.checkpoint(blk, x)\n", + " else:\n", + " x, attn = blk(x)\n", + " if not self.training:\n", + " attns.append(attn.unsqueeze(0))\n", + " if self.downsample is not None:\n", + " x = self.downsample(x)\n", + " if not self.training:\n", + " attn = torch.cat(attns, dim = 0)\n", + " attn = torch.mean(attn, dim = 0)\n", + " return x, attn\n", + "\n", + " def extra_repr(self):\n", + " return f\"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}\"\n", + "\n", + "\n", + "# The Core of HTSAT\n", + "class HTSAT_Swin_Transformer(nn.Module):\n", + " r\"\"\"HTSAT based on the Swin Transformer\n", + " Args:\n", + " spec_size (int | tuple(int)): Input Spectrogram size. Default 256\n", + " patch_size (int | tuple(int)): Patch size. Default: 4\n", + " path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4\n", + " in_chans (int): Number of input image channels. Default: 1 (mono)\n", + " num_classes (int): Number of classes for classification head. Default: 527\n", + " embed_dim (int): Patch embedding dimension. Default: 96\n", + " depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.\n", + " num_heads (tuple(int)): Number of attention heads in different layers.\n", + " window_size (int): Window size. Default: 8\n", + " mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4\n", + " qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True\n", + " qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None\n", + " drop_rate (float): Dropout rate. Default: 0\n", + " attn_drop_rate (float): Attention dropout rate. Default: 0\n", + " drop_path_rate (float): Stochastic depth rate. Default: 0.1\n", + " norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.\n", + " ape (bool): If True, add absolute position embedding to the patch embedding. Default: False\n", + " patch_norm (bool): If True, add normalization after patch embedding. Default: True\n", + " use_checkpoint (bool): Whether to use checkpointing to save memory. 
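+    "        Example (editor's sketch, assuming the HTSATConfig defaults above)::\n",
+    "            >>> model = HTSAT_Swin_Transformer(config=HTSATConfig())\n",
+    "            >>> wav = torch.randn(2, 32000 * 10)   # two 10-second clips at 32 kHz\n",
+    "            >>> out = model(wav)\n",
+    "            >>> out['clipwise_output'].shape, out['latent_output'].shape\n",
+    "            (torch.Size([2, 527]), torch.Size([2, 768]))\n",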
Default: False\n", + " config (module): The configuration Module from config.py (HTSATConfig Class)\n", + " \"\"\"\n", + "\n", + " def __init__(self, spec_size=256, patch_size=4, patch_stride=(4,4), \n", + " in_chans=1, num_classes=527,\n", + " embed_dim=96, depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],\n", + " window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,\n", + " drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n", + " norm_layer=nn.LayerNorm, \n", + " ape=False, patch_norm=True,\n", + " use_checkpoint=False, norm_before_mlp='ln', config = None, **kwargs):\n", + " super(HTSAT_Swin_Transformer, self).__init__()\n", + "\n", + " self.config = config\n", + " self.spec_size = spec_size \n", + " self.patch_stride = patch_stride\n", + " self.patch_size = patch_size\n", + " self.window_size = window_size\n", + " self.embed_dim = embed_dim\n", + " self.depths = depths\n", + " self.ape = ape\n", + " self.in_chans = in_chans\n", + " self.num_classes = num_classes\n", + " self.num_heads = num_heads\n", + " self.num_layers = len(self.depths)\n", + " self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))\n", + " \n", + " self.drop_rate = drop_rate\n", + " self.attn_drop_rate = attn_drop_rate\n", + " self.drop_path_rate = drop_path_rate\n", + "\n", + " self.qkv_bias = qkv_bias\n", + " self.qk_scale = None\n", + "\n", + " self.patch_norm = patch_norm\n", + " self.norm_layer = norm_layer if self.patch_norm else None\n", + " self.norm_before_mlp = norm_before_mlp\n", + " self.mlp_ratio = mlp_ratio\n", + "\n", + " self.use_checkpoint = use_checkpoint\n", + "\n", + " # process mel-spec ; used only once\n", + " self.freq_ratio = self.spec_size // self.config.mel_bins\n", + " window = 'hann'\n", + " center = True\n", + " pad_mode = 'reflect'\n", + " ref = 1.0\n", + " amin = 1e-10\n", + " top_db = None\n", + " self.interpolate_ratio = 32 # Downsampled ratio\n", + " # Spectrogram extractor\n", + " self.spectrogram_extractor = Spectrogram(n_fft=config.window_size, hop_length=config.hop_size, \n", + " win_length=config.window_size, window=window, center=center, pad_mode=pad_mode, \n", + " freeze_parameters=True)\n", + " # Logmel feature extractor\n", + " self.logmel_extractor = LogmelFilterBank(sr=config.sample_rate, n_fft=config.window_size, \n", + " n_mels=config.mel_bins, fmin=config.fmin, fmax=config.fmax, ref=ref, amin=amin, top_db=top_db, \n", + " freeze_parameters=True)\n", + " # Spec augmenter\n", + " self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, \n", + " freq_drop_width=8, freq_stripes_num=2) # 2 2\n", + " self.bn0 = nn.BatchNorm2d(self.config.mel_bins)\n", + "\n", + "\n", + " # split spctrogram into non-overlapping patches\n", + " self.patch_embed = PatchEmbed(\n", + " img_size=self.spec_size, patch_size=self.patch_size, in_chans=self.in_chans, \n", + " embed_dim=self.embed_dim, norm_layer=self.norm_layer, patch_stride = patch_stride)\n", + "\n", + " num_patches = self.patch_embed.num_patches\n", + " patches_resolution = self.patch_embed.grid_size\n", + " self.patches_resolution = patches_resolution\n", + "\n", + " # absolute position embedding\n", + " if self.ape:\n", + " self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.embed_dim))\n", + " trunc_normal_(self.absolute_pos_embed, std=.02)\n", + "\n", + " self.pos_drop = nn.Dropout(p=self.drop_rate)\n", + "\n", + " # stochastic depth\n", + " dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))] # stochastic depth decay rule\n", + "\n", + " # 
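+    "        # Editor's note (worked example, assuming the defaults depths=[2,2,6,2]\n",
+    "        # and drop_path_rate=0.1): torch.linspace(0, 0.1, 12) gives roughly\n",
+    "        # [0.000, 0.009, 0.018, ..., 0.091, 0.100], i.e. drop-path grows linearly\n",
+    "        # with block depth; stage i consumes dpr[sum(depths[:i]):sum(depths[:i+1])]\n",
+    "        # in the loop below, so the final block gets the full 0.1 rate.\n",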
build layers\n", + " self.layers = nn.ModuleList()\n", + " for i_layer in range(self.num_layers):\n", + " layer = BasicLayer(dim=int(self.embed_dim * 2 ** i_layer),\n", + " input_resolution=(patches_resolution[0] // (2 ** i_layer),\n", + " patches_resolution[1] // (2 ** i_layer)),\n", + " depth=self.depths[i_layer],\n", + " num_heads=self.num_heads[i_layer],\n", + " window_size=self.window_size,\n", + " mlp_ratio=self.mlp_ratio,\n", + " qkv_bias=self.qkv_bias, qk_scale=self.qk_scale,\n", + " drop=self.drop_rate, attn_drop=self.attn_drop_rate,\n", + " drop_path=dpr[sum(self.depths[:i_layer]):sum(self.depths[:i_layer + 1])],\n", + " norm_layer=self.norm_layer,\n", + " downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,\n", + " use_checkpoint=use_checkpoint,\n", + " norm_before_mlp=self.norm_before_mlp)\n", + " self.layers.append(layer)\n", + "\n", + " self.norm = self.norm_layer(self.num_features)\n", + " self.avgpool = nn.AdaptiveAvgPool1d(1)\n", + " self.maxpool = nn.AdaptiveMaxPool1d(1)\n", + "\n", + " if self.config.enable_tscam:\n", + " SF = self.spec_size // (2 ** (len(self.depths) - 1)) // self.patch_stride[0] // self.freq_ratio\n", + " self.tscam_conv = nn.Conv2d(\n", + " in_channels = self.num_features,\n", + " out_channels = self.num_classes,\n", + " kernel_size = (SF,3),\n", + " padding = (0,1)\n", + " )\n", + " self.head = nn.Linear(num_classes, num_classes)\n", + " else:\n", + " self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()\n", + "\n", + " self.apply(self._init_weights)\n", + "\n", + " def _init_weights(self, m):\n", + " if isinstance(m, nn.Linear):\n", + " trunc_normal_(m.weight, std=.02)\n", + " if isinstance(m, nn.Linear) and m.bias is not None:\n", + " nn.init.constant_(m.bias, 0)\n", + " elif isinstance(m, nn.LayerNorm):\n", + " nn.init.constant_(m.bias, 0)\n", + " nn.init.constant_(m.weight, 1.0)\n", + "\n", + " @torch.jit.ignore\n", + " def no_weight_decay(self):\n", + " return {'absolute_pos_embed'}\n", + "\n", + " @torch.jit.ignore\n", + " def no_weight_decay_keywords(self):\n", + " return {'relative_position_bias_table'}\n", + "\n", + " def forward_features(self, x):\n", + " frames_num = x.shape[2] \n", + " x = self.patch_embed(x)\n", + " if self.ape:\n", + " x = x + self.absolute_pos_embed\n", + " x = self.pos_drop(x)\n", + " for i, layer in enumerate(self.layers):\n", + " x, attn = layer(x)\n", + "\n", + " if self.config.enable_tscam:\n", + " # for x\n", + " x = self.norm(x)\n", + " B, N, C = x.shape\n", + " SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]\n", + " ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]\n", + " x = x.permute(0,2,1).contiguous().reshape(B, C, SF, ST)\n", + " B, C, F, T = x.shape\n", + " # group 2D CNN\n", + " c_freq_bin = F // self.freq_ratio\n", + " x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)\n", + " x = x.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)\n", + "\n", + " # get latent_output\n", + " latent_output = self.avgpool(torch.flatten(x,2))\n", + " latent_output = torch.flatten(latent_output, 1)\n", + "\n", + " # display the attention map, if needed\n", + " if self.config.htsat_attn_heatmap:\n", + " # for attn\n", + " attn = torch.mean(attn, dim = 1)\n", + " attn = torch.mean(attn, dim = 1)\n", + " attn = attn.reshape(B, SF, ST)\n", + " c_freq_bin = SF // self.freq_ratio\n", + " attn = attn.reshape(B, SF // c_freq_bin, c_freq_bin, ST) \n", + " attn = attn.permute(0,2,1,3).contiguous().reshape(B, 
c_freq_bin, -1)\n", + " attn = attn.mean(dim = 1)\n", + " attn_max = torch.max(attn, dim = 1, keepdim = True)[0]\n", + " attn_min = torch.min(attn, dim = 1, keepdim = True)[0]\n", + " attn = ((attn * 0.15) + (attn_max * 0.85 - attn_min)) / (attn_max - attn_min)\n", + " attn = attn.unsqueeze(dim = 2)\n", + "\n", + " x = self.tscam_conv(x)\n", + " x = torch.flatten(x, 2) # B, C, T\n", + "\n", + " if self.config.htsat_attn_heatmap:\n", + " fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous() * attn, 8 * self.patch_stride[1]) \n", + " else: \n", + " fpx = interpolate(torch.sigmoid(x).permute(0,2,1).contiguous(), 8 * self.patch_stride[1]) \n", + " \n", + " x = self.avgpool(x)\n", + " x = torch.flatten(x, 1)\n", + "\n", + " if self.config.loss_type == \"clip_ce\":\n", + " output_dict = {\n", + " 'framewise_output': fpx, # already sigmoided\n", + " 'clipwise_output': x,\n", + " 'latent_output': latent_output\n", + " }\n", + " else:\n", + " output_dict = {\n", + " 'framewise_output': fpx, # already sigmoided\n", + " 'clipwise_output': torch.sigmoid(x),\n", + " 'latent_output': latent_output\n", + " }\n", + " \n", + " else:\n", + " x = self.norm(x) # B N C\n", + " B, N, C = x.shape\n", + " \n", + " fpx = x.permute(0,2,1).contiguous().reshape(B, C, frames_num // (2 ** (len(self.depths) + 1)), frames_num // (2 ** (len(self.depths) + 1)) )\n", + " B, C, F, T = fpx.shape\n", + " c_freq_bin = F // self.freq_ratio\n", + " fpx = fpx.reshape(B, C, F // c_freq_bin, c_freq_bin, T)\n", + " fpx = fpx.permute(0,1,3,2,4).contiguous().reshape(B, C, c_freq_bin, -1)\n", + " fpx = torch.sum(fpx, dim = 2)\n", + " fpx = interpolate(fpx.permute(0,2,1).contiguous(), 8 * self.patch_stride[1]) \n", + " x = self.avgpool(x.transpose(1, 2)) # B C 1\n", + " x = torch.flatten(x, 1)\n", + " if self.num_classes > 0:\n", + " x = self.head(x)\n", + " fpx = self.head(fpx)\n", + " output_dict = {'framewise_output': torch.sigmoid(fpx), \n", + " 'clipwise_output': torch.sigmoid(x)}\n", + " return output_dict\n", + "\n", + " def crop_wav(self, x, crop_size, spe_pos = None):\n", + " time_steps = x.shape[2]\n", + " tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)\n", + " for i in range(len(x)):\n", + " if spe_pos is None:\n", + " crop_pos = random.randint(0, time_steps - crop_size - 1)\n", + " else:\n", + " crop_pos = spe_pos\n", + " tx[i][0] = x[i, 0, crop_pos:crop_pos + crop_size,:]\n", + " return tx\n", + "\n", + " # Reshape the wavform to a img size, if you want to use the pretrained swin transformer model\n", + " def reshape_wav2img(self, x):\n", + " B, C, T, F = x.shape\n", + " target_T = int(self.spec_size * self.freq_ratio)\n", + " target_F = self.spec_size // self.freq_ratio\n", + " assert T <= target_T and F <= target_F, \"the wav size should less than or equal to the swin input size\"\n", + " # to avoid bicubic zero error\n", + " if T < target_T:\n", + " x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode=\"bicubic\", align_corners=True)\n", + " if F < target_F:\n", + " x = nn.functional.interpolate(x, (x.shape[2], target_F), mode=\"bicubic\", align_corners=True)\n", + " x = x.permute(0,1,3,2).contiguous()\n", + " x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)\n", + " # print(x.shape)\n", + " x = x.permute(0,1,3,2,4).contiguous()\n", + " x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])\n", + " return x\n", + " \n", + " # Repeat the wavform to a img size, if you want to use the pretrained swin transformer 
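+    "    # Editor's note on reshape_wav2img above (worked shape example, defaults\n",
+    "    # assumed): with spec_size=256 and mel_bins=64, freq_ratio = 256 // 64 = 4,\n",
+    "    # so a log-mel input interpolated to (B, 1, 1024, 64) is permuted to\n",
+    "    # (B, 1, 64, 1024), cut into 4 chunks of 256 frames, and stacked along the\n",
+    "    # frequency axis, yielding a (B, 1, 256, 256) single-channel 'image' that\n",
+    "    # matches the Swin backbone input size.\n",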
model\n", + " def repeat_wat2img(self, x, cur_pos):\n", + " B, C, T, F = x.shape\n", + " target_T = int(self.spec_size * self.freq_ratio)\n", + " target_F = self.spec_size // self.freq_ratio\n", + " assert T <= target_T and F <= target_F, \"the wav size should less than or equal to the swin input size\"\n", + " # to avoid bicubic zero error\n", + " if T < target_T:\n", + " x = nn.functional.interpolate(x, (target_T, x.shape[3]), mode=\"bicubic\", align_corners=True)\n", + " if F < target_F:\n", + " x = nn.functional.interpolate(x, (x.shape[2], target_F), mode=\"bicubic\", align_corners=True) \n", + " x = x.permute(0,1,3,2).contiguous() # B C F T\n", + " x = x[:,:,:,cur_pos:cur_pos + self.spec_size]\n", + " x = x.repeat(repeats = (1,1,4,1))\n", + " return x\n", + "\n", + " def forward(self, x: torch.Tensor, mixup_lambda = None, infer_mode = False):# out_feat_keys: List[str] = None):\n", + " x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)\n", + " x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)\n", + " \n", + " \n", + " x = x.transpose(1, 3)\n", + " x = self.bn0(x)\n", + " x = x.transpose(1, 3)\n", + " if self.training:\n", + " x = self.spec_augmenter(x)\n", + " if self.training and mixup_lambda is not None:\n", + " x = do_mixup(x, mixup_lambda)\n", + " \n", + " if infer_mode:\n", + " # in infer mode. we need to handle different length audio input\n", + " frame_num = x.shape[2]\n", + " target_T = int(self.spec_size * self.freq_ratio)\n", + " repeat_ratio = math.floor(target_T / frame_num)\n", + " x = x.repeat(repeats=(1,1,repeat_ratio,1))\n", + " x = self.reshape_wav2img(x)\n", + " output_dict = self.forward_features(x)\n", + " elif self.config.enable_repeat_mode:\n", + " if self.training:\n", + " cur_pos = random.randint(0, (self.freq_ratio - 1) * self.spec_size - 1)\n", + " x = self.repeat_wat2img(x, cur_pos)\n", + " output_dict = self.forward_features(x)\n", + " else:\n", + " output_dicts = []\n", + " for cur_pos in range(0, (self.freq_ratio - 1) * self.spec_size + 1, self.spec_size):\n", + " tx = x.clone()\n", + " tx = self.repeat_wat2img(tx, cur_pos)\n", + " output_dicts.append(self.forward_features(tx))\n", + " clipwise_output = torch.zeros_like(output_dicts[0][\"clipwise_output\"]).float().to(x.device)\n", + " framewise_output = torch.zeros_like(output_dicts[0][\"framewise_output\"]).float().to(x.device)\n", + " for d in output_dicts:\n", + " clipwise_output += d[\"clipwise_output\"]\n", + " framewise_output += d[\"framewise_output\"]\n", + " clipwise_output = clipwise_output / len(output_dicts)\n", + " framewise_output = framewise_output / len(output_dicts)\n", + "\n", + " output_dict = {\n", + " 'framewise_output': framewise_output, \n", + " 'clipwise_output': clipwise_output\n", + " }\n", + " else:\n", + " if x.shape[2] > self.freq_ratio * self.spec_size:\n", + " if self.training:\n", + " x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)\n", + " x = self.reshape_wav2img(x)\n", + " output_dict = self.forward_features(x)\n", + " else:\n", + " # Change: Hard code here\n", + " overlap_size = 344 #(x.shape[2] - 1) // 4\n", + " output_dicts = []\n", + " crop_size = 689 #(x.shape[2] - 1) // 2\n", + " for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):\n", + " tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)\n", + " tx = self.reshape_wav2img(tx)\n", + " output_dicts.append(self.forward_features(tx))\n", + " clipwise_output = 
torch.zeros_like(output_dicts[0][\"clipwise_output\"]).float().to(x.device)\n", + " framewise_output = torch.zeros_like(output_dicts[0][\"framewise_output\"]).float().to(x.device)\n", + " latent_output = torch.zeros_like(output_dicts[0][\"latent_output\"]).float().to(x.device)\n", + " for d in output_dicts:\n", + " clipwise_output += d[\"clipwise_output\"]\n", + " framewise_output += d[\"framewise_output\"]\n", + " latent_output += d[\"latent_output\"]\n", + " clipwise_output = clipwise_output / len(output_dicts)\n", + " framewise_output = framewise_output / len(output_dicts)\n", + " latent_output = latent_output / len(output_dicts)\n", + " output_dict = {\n", + " 'framewise_output': framewise_output, \n", + " 'clipwise_output': clipwise_output,\n", + " 'latent_output': latent_output,\n", + " }\n", + " else: # this part is typically used, and most easy one\n", + " x = self.reshape_wav2img(x)\n", + " output_dict = self.forward_features(x)\n", + " # x = self.head(x)\n", + " return output_dict\n", + "\n", + "class HTSATWrapper(nn.Module):\n", + " def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, \n", + " fmax, classes_num, out_emb):\n", + " super().__init__()\n", + "\n", + " # print(\"parameters are being overidden when using HTSAT\")\n", + " # print(\"HTSAT only support loading a pretrained model on AudioSet\")\n", + " # @TODO later look at what parameters are same and can be merged\n", + "\n", + " self.htsat = HTSAT_Swin_Transformer(config=HTSATConfig())\n", + "\n", + " def forward(self, x):\n", + " out_dict = self.htsat(x)\n", + " out_dict['embedding'] = out_dict['latent_output']\n", + " return out_dict\n", + "\n", + "\n", + "def get_audio_encoder(name: str):\n", + " if name == \"HTSAT\":\n", + " return HTSATWrapper\n", + " else:\n", + " raise Exception('The audio encoder name {} is incorrect or not supported'.format(name))\n", + "\n", + "class Projection(nn.Module):\n", + " def __init__(self, dim_imgn: int, d_out: int, p: float=0.5) -> None:\n", + " super().__init__()\n", + " self.linear1 = nn.Linear(dim_imgn, d_out, bias=False)\n", + " self.linear2 = nn.Linear(d_out, d_out, bias=False)\n", + " self.layer_norm = nn.LayerNorm(d_out)\n", + " self.drop = nn.Dropout(p)\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " embed1 = self.linear1(x)\n", + " embed2 = self.drop(self.linear2(F.gelu(embed1)))\n", + " embeds = self.layer_norm(embed1 + embed2)\n", + " return embeds\n", + "\n", + "class AudioEncoder(nn.Module):\n", + " def __init__(self, audioenc_name:str, dim_imgn: int, d_out: int, sample_rate: int, window_size: int,\n", + " hop_size: int, mel_bins: int, fmin: int, fmax: int, classes_num: int) -> None:\n", + " super().__init__()\n", + "\n", + " audio_encoder = get_audio_encoder(audioenc_name)\n", + "\n", + " self.base = audio_encoder(\n", + " sample_rate, window_size,\n", + " hop_size, mel_bins, fmin, fmax,\n", + " classes_num, dim_imgn)\n", + "\n", + " self.projection = Projection(dim_imgn, d_out)\n", + "\n", + " def forward(self, x):\n", + " out_dict = self.base(x)\n", + " audio_features, audio_classification_output = out_dict['embedding'], out_dict['clipwise_output']\n", + " projected_vec = self.projection(audio_features)\n", + " return projected_vec, audio_classification_output\n", + "\n", + "class TextEncoder(nn.Module):\n", + " def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:\n", + " super().__init__()\n", + " self.text_model = text_model\n", + " self.base = AutoModel.from_pretrained(text_model)\n", + 
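+    "        # Editor's note: with the default text_model='gpt2' (hidden size 768,\n",
+    "        # matching transformer_embed_dim), forward() below pools one embedding\n",
+    "        # per sentence by indexing the hidden state of the last non-padding token\n",
+    "        # (token id 0 is treated as padding), and Projection then maps 768 -> d_proj.\n",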
"\n", + " if 'clip' in text_model:\n", + " self.clip_text_projection = self.base.text_projection\n", + " self.base = self.base.text_model\n", + " if 'base' in text_model:\n", + " transformer_embed_dim = 512\n", + " \n", + " self.projection = Projection(transformer_embed_dim, d_out)\n", + "\n", + " def forward(self, x):\n", + " if 'clip' in self.text_model:\n", + " pooled_output = self.base(**x)[1] # get pooled output\n", + " out = self.clip_text_projection(pooled_output) # get CLS token output\n", + " elif 'gpt' in self.text_model:\n", + " batch_size = x['input_ids'].shape[0]\n", + " hidden_states = self.base(**x)[0] # (batch_size=4, seq_len, 768)\n", + "\n", + " sequence_lengths = torch.ne(x['input_ids'], 0).sum(-1) - 1 # tensor([13, 14, 18, 17])\n", + " out = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths] # [batch_size, 768] = [4, 768]\n", + " else:\n", + " out = self.base(**x)[0]\n", + " out = out[:, 0, :] # get CLS token output\n", + " \n", + " projected_vec = self.projection(out)\n", + "\n", + " return projected_vec\n", + "\n", + "class CLAP(nn.Module):\n", + " def __init__(self,\n", + " # audio\n", + " audioenc_name: str,\n", + " sample_rate: int, \n", + " window_size: int, \n", + " hop_size: int, \n", + " mel_bins: int, \n", + " fmin: int, \n", + " fmax: int, \n", + " classes_num: int, \n", + " out_emb: int,\n", + " # text\n", + " text_model: str,\n", + " transformer_embed_dim: int,\n", + " # common\n", + " d_proj: int,\n", + " ):\n", + " super().__init__()\n", + "\n", + " \n", + " self.audio_encoder = AudioEncoder(\n", + " audioenc_name, out_emb, d_proj,\n", + " sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num)\n", + "\n", + " self.caption_encoder = TextEncoder(\n", + " d_proj, text_model, transformer_embed_dim\n", + " )\n", + "\n", + " self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\n", + "\n", + " def forward(self, audio, text):\n", + " audio_embed, _ = self.audio_encoder(audio)\n", + " caption_embed = self.caption_encoder(text)\n", + "\n", + " return caption_embed, audio_embed, self.logit_scale.exp()\n", + " \n", + " \n", + " \n", + "# ==================================================================\n", + "# A U D I O - P R E - P R O C E S S I N G\n", + "# ==================================================================\n", + "def read_audio(audio_path, resample=True):\n", + " r\"\"\"Loads audio file or array and returns a torch tensor\"\"\"\n", + " # Randomly sample a segment of audio_duration from the clip or pad to match duration\n", + " audio_time_series, sample_rate = torchaudio.load(audio_path)\n", + "\n", + " resample_rate = clapConfig.sample_rate\n", + " if resample and resample_rate != sample_rate:\n", + " resampler = T.Resample(sample_rate, resample_rate)\n", + " audio_time_series = resampler(audio_time_series)\n", + " return audio_time_series, resample_rate\n", + "\n", + "def load_audio_into_tensor(audio_path, audio_duration, resample=False):\n", + " r\"\"\"Loads audio file and returns raw audio.\"\"\"\n", + " # Randomly sample a segment of audio_duration from the clip or pad to match duration\n", + " audio_time_series, sample_rate = read_audio(audio_path, resample)\n", + " audio_time_series = audio_time_series.reshape(-1)\n", + "\n", + " # audio_time_series is shorter than predefined audio duration,\n", + " # so audio_time_series is extended\n", + " if audio_duration*sample_rate >= audio_time_series.shape[0]:\n", + " repeat_factor = int(np.ceil((audio_duration*sample_rate) /\n", + " 
audio_time_series.shape[0]))\n", + " # Repeat audio_time_series by repeat_factor to match audio_duration\n", + " audio_time_series = audio_time_series.repeat(repeat_factor)\n", + " # remove excess part of audio_time_series\n", + " audio_time_series = audio_time_series[0:audio_duration*sample_rate]\n", + " else:\n", + " # audio_time_series is longer than predefined audio duration,\n", + " # so audio_time_series is trimmed\n", + " start_index = random.randrange(\n", + " audio_time_series.shape[0] - audio_duration*sample_rate)\n", + " audio_time_series = audio_time_series[start_index:start_index +\n", + " audio_duration*sample_rate]\n", + " return torch.FloatTensor(audio_time_series)\n", + "\n", + "np_str_obj_array_pattern = re.compile(r'[SaUO]')\n", + "default_collate_err_msg_format = (\n", + " \"default_collate: batch must contain tensors, numpy arrays, numbers, \"\n", + " \"dicts or lists; found {}\")\n", + "\n", + "def default_collate(batch):\n", + " r\"\"\"Puts each data field into a tensor with outer dimension batch size\"\"\"\n", + " elem = batch[0]\n", + " elem_type = type(elem)\n", + " if isinstance(elem, torch.Tensor):\n", + " out = None\n", + " if torch.utils.data.get_worker_info() is not None:\n", + " # If we're in a background process, concatenate directly into a\n", + " # shared memory tensor to avoid an extra copy\n", + " numel = sum([x.numel() for x in batch])\n", + " storage = elem.storage()._new_shared(numel)\n", + " out = elem.new(storage)\n", + " return torch.stack(batch, 0, out=out)\n", + " elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \\\n", + " and elem_type.__name__ != 'string_':\n", + " if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':\n", + " # array of string classes and object\n", + " if np_str_obj_array_pattern.search(elem.dtype.str) is not None:\n", + " raise TypeError(\n", + " default_collate_err_msg_format.format(elem.dtype))\n", + "\n", + " return default_collate([torch.as_tensor(b) for b in batch])\n", + " elif elem.shape == (): # scalars\n", + " return torch.as_tensor(batch)\n", + " elif isinstance(elem, float):\n", + " return torch.tensor(batch, dtype=torch.float64)\n", + " elif isinstance(elem, int):\n", + " return torch.tensor(batch)\n", + " elif isinstance(elem, str):\n", + " return batch\n", + " elif isinstance(elem, collections.abc.Mapping):\n", + " return {key: default_collate([d[key] for d in batch]) for key in elem}\n", + " elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple\n", + " return elem_type(*(default_collate(samples) for samples in zip(*batch)))\n", + " elif isinstance(elem, collections.abc.Sequence):\n", + " # check to make sure that the elements in batch have consistent size\n", + " it = iter(batch)\n", + " elem_size = len(next(it))\n", + " if not all(len(elem) == elem_size for elem in it):\n", + " raise RuntimeError(\n", + " 'each element in list of batch should be of equal size')\n", + " transposed = zip(*batch)\n", + " return [default_collate(samples) for samples in transposed]\n", + "\n", + " raise TypeError(default_collate_err_msg_format.format(elem_type))\n", + "\n", + "def preprocess_audio(audio_files, resample):\n", + " r\"\"\"Load list of audio files and return raw audio\"\"\"\n", + " audio_tensors = []\n", + " for audio_file in audio_files:\n", + " audio_tensor = load_audio_into_tensor(\n", + " audio_file, clapConfig.duration, resample)\n", + " audio_tensor = audio_tensor.reshape(1, -1)\n", + " audio_tensors.append(audio_tensor)\n", + " return 
default_collate(audio_tensors)\n", + "\n", + "\n", + "\n", + "# ==================================================================\n", + "# A U D I O - E M B E D D I N G S - H E L P E R\n", + "# ==================================================================\n", + "def CLAPAudioProcessor(audio_files: List[str], resample=True):\n", + " preprocessed_audio = preprocess_audio(audio_files, resample)\n", + " preprocessed_audio = preprocessed_audio.reshape(\n", + " preprocessed_audio.shape[0], preprocessed_audio.shape[2])\n", + " preprocessed_audio = preprocessed_audio\n", + " return preprocessed_audio\n", + "\n", + "def get_audio_embeddings(audio_files: List[str], audio_encoder, resample=True):\n", + " \"\"\"Load list of audio files and return audio embeddings\"\"\"\n", + " # preprocessed_audio = preprocess_audio(audio_files, resample)\n", + " # with torch.no_grad():\n", + " # preprocessed_audio = preprocessed_audio.reshape(\n", + " # preprocessed_audio.shape[0], preprocessed_audio.shape[2])\n", + " with torch.no_grad():\n", + " preprocessed_audio = CLAPAudioProcessor(audio_files, resample)\n", + " return audio_encoder(preprocessed_audio)[0]\n", + "\n", + "\n", + "# ==================================================================\n", + "# C L A P\n", + "# ==================================================================\n", + "class ClapConfig:\n", + " # TEXT ENCODER CONFIG\n", + " text_model = 'gpt2'\n", + " text_len = 77\n", + " transformer_embed_dim = 768\n", + " freeze_text_encoder_weights = True\n", + "\n", + " # AUDIO ENCODER CONFIG\n", + " audioenc_name = 'HTSAT'\n", + " out_emb = 768\n", + " sample_rate = 44100\n", + " duration = 7\n", + " fmin = 50\n", + " fmax = 8000 # 14000\n", + " n_fft = 1024 # 1028\n", + " hop_size = 320\n", + " mel_bins = 64\n", + " window_size = 1024\n", + "\n", + " # PROJECTION SPACE CONFIG \n", + " d_proj = 1024\n", + " temperature = 0.003\n", + "\n", + " # TRAINING AND EVALUATION CONFIG\n", + " num_classes = 527\n", + " batch_size = 1024\n", + " demo = False\n", + " \n", + "\n", + "clapConfig = ClapConfig()\n", + "clap = CLAP(\n", + " audioenc_name=clapConfig.audioenc_name,\n", + " sample_rate=clapConfig.sample_rate,\n", + " window_size=clapConfig.window_size,\n", + " hop_size=clapConfig.hop_size,\n", + " mel_bins=clapConfig.mel_bins,\n", + " fmin=clapConfig.fmin,\n", + " fmax=clapConfig.fmax,\n", + " classes_num=clapConfig.num_classes,\n", + " out_emb=clapConfig.out_emb,\n", + " text_model=clapConfig.text_model,\n", + " transformer_embed_dim=clapConfig.transformer_embed_dim,\n", + " d_proj=clapConfig.d_proj\n", + " )\n", + "\n", + "model_repo = \"microsoft/msclap\"\n", + "model_name = {\n", + " '2022': 'CLAP_weights_2022.pth',\n", + " '2023': 'CLAP_weights_2023.pth',\n", + " 'clapcap': 'clapcap_weights_2023.pth'\n", + "}\n", + "\n", + "version = '2023'\n", + "model_fp = hf_hub_download(model_repo, model_name[version])\n", + "\n", + "model_state_dict = torch.load(model_fp, map_location=torch.device('cpu'))['model']\n", + "clap.load_state_dict(model_state_dict, strict=False)\n", + "# clap.eval()\n", + "\n", + "clap_audio_encoder = clap.audio_encoder.to(device)\n", + "\n", + "# ENGLISH_AUDIO_DIR = r\"/home/IITB/ai-at-ieor/23m1521/datasets/Vaani/Audios/English\"\n", + "# audio_files = [os.path.join(ENGLISH_AUDIO_DIR, i) for i in os.listdir(ENGLISH_AUDIO_DIR) if i.endswith(\".wav\")]\n", + "# audio_embedding = get_audio_embeddings(audio_files, clap_audio_encoder)\n", + "# print(\"CLAP Audio Encoder Embeddings:\", audio_embedding.shape) # [5, 1024]\n", + 
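"\n",
+ "# ------------------------------------------------------------------\n",
+ "# Illustrative sketch only (not part of the original pipeline): score audio\n",
+ "# clips against text captions with the CLAP model loaded above. The helper\n",
+ "# name, file paths and captions below are hypothetical placeholders; padding\n",
+ "# uses '!' (GPT-2 token id 0) so that the TextEncoder GPT-2 branch, which\n",
+ "# measures sequence length by counting non-zero input_ids, stays correct.\n",
+ "# ------------------------------------------------------------------\n",
+ "def clap_audio_text_similarity(audio_files: List[str], captions: List[str]) -> torch.Tensor:\n",
+ "    tokenizer = AutoTokenizer.from_pretrained(clapConfig.text_model)\n",
+ "    tokenizer.add_special_tokens({'pad_token': '!'})\n",
+ "    text_inputs = tokenizer(captions, padding=True, truncation=True,\n",
+ "                            max_length=clapConfig.text_len, return_tensors='pt')\n",
+ "    with torch.no_grad():\n",
+ "        audio_embed = get_audio_embeddings(audio_files, clap_audio_encoder)   # [N, d_proj]\n",
+ "        caption_embed = clap.caption_encoder(text_inputs)                     # [M, d_proj]\n",
+ "        audio_embed = audio_embed / audio_embed.norm(dim=-1, keepdim=True)\n",
+ "        caption_embed = caption_embed / caption_embed.norm(dim=-1, keepdim=True)\n",
+ "        # temperature-scaled cosine similarities: rows = audio clips, cols = captions\n",
+ "        logits = clap.logit_scale.exp() * audio_embed @ caption_embed.t()\n",
+ "    return logits\n",
+ "# e.g.: scores = clap_audio_text_similarity(['clip1.wav', 'clip2.wav'],\n",
+ "#                                           ['a dog barking', 'heavy rain'])\n",
+ 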
"\n", + "\n", + "# ==================================================================\n", + "# C L A P - L o R A - M O D E L\n", + "# ==================================================================\n", + "LoRAconfig = {\n", + " \"peft_type\": \"LORA\",\n", + " \"task_type\": \"FEATURE_EXTRACTION\",\n", + " \"inference_mode\": False,\n", + " \"r\": 16,\n", + " \"target_modules\": [\"qkv\", \"fc1\", \"fc2\", \"proj\", \"linear1\", \"linear2\"],\n", + " \"lora_alpha\": 32,\n", + " \"lora_dropout\": 0.05,\n", + " \"fan_in_fan_out\": False,\n", + " \"bias\": \"all\",\n", + "}\n", + "peft_config = get_peft_config(LoRAconfig)\n", + "\n", + "peft_model = get_peft_model(clap_audio_encoder, peft_config)\n", + "\n", + "peft_model.print_trainable_parameters()\n", + "\n", + "peft_clap_audio_encoder = peft_model.base_model\n", + "# audio_embedding = get_audio_embeddings(audio_files, peft_clap_audio_encoder)\n", + "# print(\"CLAP LoRA Audio Encoder Embeddings:\", audio_embedding.shape) # [5, 1024]\n", + "\n", + "\n", + "\n", + "# ==================================================================\n", + "# O P E N - C L I P - M O D E L\n", + "# ==================================================================\n", + "# ==================================================================\n", + "# I M P O R T S\n", + "# ==================================================================\n", + "\n", + "\n", + "import os\n", + "import io\n", + "import sys\n", + "import math\n", + "import random\n", + "import collections\n", + "import collections.abc\n", + "import re\n", + "from itertools import repeat\n", + "from pathlib import Path\n", + "from typing import Optional, Tuple, Union, List, Dict\n", + "\n", + "import csv\n", + "import numpy as np\n", + "import pandas as pd\n", + "from PIL import Image\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt\n", + "from tqdm import trange, tqdm\n", + "\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torch import nn\n", + "from torch.nn.init import _calculate_fan_in_and_fan_out\n", + "import torch.utils.checkpoint as checkpoint\n", + "\n", + "import torchvision\n", + "from torchvision.transforms import v2\n", + "\n", + "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n", + "# device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "# print(f\"Using device: {device}\")\n", + "\n", + "import torchaudio\n", + "import torchaudio.transforms as T\n", + "from torchlibrosa.stft import Spectrogram, LogmelFilterBank\n", + "from torchlibrosa.augmentation import SpecAugmentation\n", + "\n", + "from transformers import AutoModel, AutoTokenizer, logging\n", + "from huggingface_hub.file_download import hf_hub_download\n", + "from huggingface_hub.file_download import hf_hub_download\n", + "from peft import get_peft_config, get_peft_model\n", + "\n", + "from typing import Any, Dict, Optional, Tuple, Union\n", + "import numbers\n", + "import random\n", + "import warnings\n", + "from dataclasses import dataclass, asdict\n", + "from typing import Any, Dict, List, Optional, Sequence, Tuple, Union\n", + "\n", + "import torch\n", + "import torchvision.transforms.functional as F\n", + "from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \\\n", + " CenterCrop, ColorJitter, Grayscale\n", + "\n", + "OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)\n", + "OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)\n", + "IMAGENET_MEAN = (0.485, 0.456, 0.406)\n", + 
"IMAGENET_STD = (0.229, 0.224, 0.225)\n", + "INCEPTION_MEAN = (0.5, 0.5, 0.5)\n", + "INCEPTION_STD = (0.5, 0.5, 0.5)\n", + "\n", + "# Default name for a weights file hosted on the Huggingface Hub.\n", + "HF_WEIGHTS_NAME = \"open_clip_pytorch_model.bin\" # default pytorch pkl\n", + "HF_SAFE_WEIGHTS_NAME = \"open_clip_model.safetensors\" # safetensors version\n", + "HF_CONFIG_NAME = 'open_clip_config.json'\n", + "\n", + "\n", + "import collections.abc\n", + "from itertools import repeat\n", + "from typing import List, Optional, Tuple, Union\n", + "\n", + "import torch\n", + "from torch import nn as nn\n", + "from torch import _assert\n", + "from torchvision.ops.misc import FrozenBatchNorm2d\n", + "\n", + "\n", + "def freeze_batch_norm_2d(module, module_match={}, name=''):\n", + " \"\"\"\n", + " Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is\n", + " itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and\n", + " returned. Otherwise, the module is walked recursively and submodules are converted in place.\n", + "\n", + " Args:\n", + " module (torch.nn.Module): Any PyTorch module.\n", + " module_match (dict): Dictionary of full module names to freeze (all if empty)\n", + " name (str): Full module name (prefix)\n", + "\n", + " Returns:\n", + " torch.nn.Module: Resulting module\n", + "\n", + " Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762\n", + " \"\"\"\n", + " res = module\n", + " is_match = True\n", + " if module_match:\n", + " is_match = name in module_match\n", + " if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):\n", + " res = FrozenBatchNorm2d(module.num_features)\n", + " res.num_features = module.num_features\n", + " res.affine = module.affine\n", + " if module.affine:\n", + " res.weight.data = module.weight.data.clone().detach()\n", + " res.bias.data = module.bias.data.clone().detach()\n", + " res.running_mean.data = module.running_mean.data\n", + " res.running_var.data = module.running_var.data\n", + " res.eps = module.eps\n", + " else:\n", + " for child_name, child in module.named_children():\n", + " full_child_name = '.'.join([name, child_name]) if name else child_name\n", + " new_child = freeze_batch_norm_2d(child, module_match, full_child_name)\n", + " if new_child is not child:\n", + " res.add_module(child_name, new_child)\n", + " return res\n", + "\n", + "\n", + "# From PyTorch internals\n", + "def _ntuple(n):\n", + " def parse(x):\n", + " if isinstance(x, collections.abc.Iterable):\n", + " return x\n", + " return tuple(repeat(x, n))\n", + " return parse\n", + "\n", + "\n", + "to_1tuple = _ntuple(1)\n", + "to_2tuple = _ntuple(2)\n", + "to_3tuple = _ntuple(3)\n", + "to_4tuple = _ntuple(4)\n", + "to_ntuple = lambda n, x: _ntuple(n)(x)\n", + "\n", + "# Replaces all linear layers with linear_replacement\n", + "# TODO: add int8 support for other linear layers including attn and convnets\n", + "def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True):\n", + " for name, module in model.named_children():\n", + " if len(list(module.children())) > 0:\n", + " replace_linear(module, linear_replacement, include_modules, copy_weights)\n", + "\n", + " if isinstance(module, torch.nn.Linear) and name in include_modules:\n", + " old_module = model._modules[name]\n", + " model._modules[name] = 
linear_replacement(\n", + " module.in_features,\n", + " module.out_features,\n", + " module.bias is not None,\n", + " )\n", + " if copy_weights:\n", + " model._modules[name].weight.data.copy_(old_module.weight.data)\n", + " if model._modules[name].bias is not None:\n", + " model._modules[name].bias.data.copy_(old_module.bias)\n", + "\n", + " return model\n", + "\n", + "def convert_int8_model_to_inference_mode(model):\n", + " for m in model.modules():\n", + " if hasattr(m, 'prepare_for_eval'):\n", + " int8_original_dtype = m.weight.dtype\n", + " m.prepare_for_eval()\n", + " m.int8_original_dtype = int8_original_dtype\n", + "\n", + "\n", + "def feature_take_indices(\n", + " num_features: int,\n", + " indices: Optional[Union[int, List[int]]] = None,\n", + " as_set: bool = False,\n", + ") -> Tuple[List[int], int]:\n", + " \"\"\" Determine the absolute feature indices to 'take' from.\n", + "\n", + " Note: This function can be called in forward() so must be torchscript compatible,\n", + " which requires some incomplete typing and workaround hacks.\n", + "\n", + " Args:\n", + " num_features: total number of features to select from\n", + " indices: indices to select,\n", + " None -> select all\n", + " int -> select last n\n", + " list/tuple of int -> return specified (-ve indices specify from end)\n", + " as_set: return as a set\n", + "\n", + " Returns:\n", + " List (or set) of absolute (from beginning) indices, Maximum index\n", + " \"\"\"\n", + " if indices is None:\n", + " indices = num_features # all features if None\n", + "\n", + " if isinstance(indices, int):\n", + " # convert int -> last n indices\n", + " _assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})')\n", + " take_indices = [num_features - indices + i for i in range(indices)]\n", + " else:\n", + " take_indices: List[int] = []\n", + " for i in indices:\n", + " idx = num_features + i if i < 0 else i\n", + " _assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})')\n", + " take_indices.append(idx)\n", + "\n", + " if not torch.jit.is_scripting() and as_set:\n", + " return set(take_indices), max(take_indices)\n", + "\n", + " return take_indices, max(take_indices)\n", + "\n", + "\n", + "def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:\n", + " if isinstance(x, int):\n", + " # if indices is an int, take last N features\n", + " return tuple(range(-x, 0))\n", + " return tuple(x)\n", + "\n", + "\n", + "\n", + "import copy\n", + "import copy\n", + "import hashlib\n", + "import os\n", + "import urllib\n", + "import warnings\n", + "from functools import partial\n", + "from typing import Dict, Iterable, Optional, Union\n", + "\n", + "from tqdm import tqdm\n", + "\n", + "\n", + "try:\n", + " import safetensors.torch\n", + " _has_safetensors = True\n", + "except ImportError:\n", + " _has_safetensors = False\n", + "\n", + "__version__ = '2.32.0'\n", + "\n", + "\n", + "\"\"\" CLIP Model\n", + "\n", + "Adapted from https://github.com/openai/CLIP. 
Originally MIT License, Copyright (c) 2021 OpenAI.\n", + "\"\"\"\n", + "import copy\n", + "import logging\n", + "import math\n", + "from dataclasses import dataclass\n", + "from typing import Any, Dict, List, Optional, Tuple, Union\n", + "\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torch import nn\n", + "from torch.utils.checkpoint import checkpoint\n", + "from functools import partial\n", + "\n", + "# from .hf_model import HFTextEncoder\n", + "# from .modified_resnet import ModifiedResNet\n", + "from collections import OrderedDict\n", + "import math\n", + "from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union\n", + "\n", + "import torch\n", + "from torch import nn\n", + "from torch.nn import functional as F\n", + "from torch.utils.checkpoint import checkpoint\n", + "\n", + "# from .utils import to_2tuple, feature_take_indices\n", + "# from .pos_embed import get_2d_sincos_pos_embed\n", + "# Copyright (c) Meta Platforms, Inc. and affiliates.\n", + "# All rights reserved.\n", + "\n", + "# This source code is licensed under the license found in the\n", + "# LICENSE file in the root directory of this source tree.\n", + "# --------------------------------------------------------\n", + "# Position embedding utils\n", + "# --------------------------------------------------------\n", + "\n", + "import numpy as np\n", + "\n", + "import torch\n", + "\n", + "# --------------------------------------------------------\n", + "# 2D sine-cosine position embedding\n", + "# References:\n", + "# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py\n", + "# MoCo v3: https://github.com/facebookresearch/moco-v3\n", + "# --------------------------------------------------------\n", + "def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):\n", + " \"\"\"\n", + " grid_size: int of the grid height and width\n", + " return:\n", + " pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)\n", + " \"\"\"\n", + " grid_h = np.arange(grid_size, dtype=np.float32)\n", + " grid_w = np.arange(grid_size, dtype=np.float32)\n", + " grid = np.meshgrid(grid_w, grid_h) # here w goes first\n", + " grid = np.stack(grid, axis=0)\n", + "\n", + " grid = grid.reshape([2, 1, grid_size, grid_size])\n", + " pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)\n", + " if cls_token:\n", + " pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)\n", + " return pos_embed\n", + "\n", + "\n", + "def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):\n", + " assert embed_dim % 2 == 0\n", + "\n", + " # use half of dimensions to encode grid_h\n", + " emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)\n", + " emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)\n", + "\n", + " emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)\n", + " return emb\n", + "\n", + "\n", + "def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):\n", + " \"\"\"\n", + " embed_dim: output dimension for each position\n", + " pos: a list of positions to be encoded: size (M,)\n", + " out: (M, D)\n", + " \"\"\"\n", + " assert embed_dim % 2 == 0\n", + " omega = np.arange(embed_dim // 2, dtype=float)\n", + " omega /= embed_dim / 2.\n", + " omega = 1. 
/ 10000**omega # (D/2,)\n", + "\n", + " pos = pos.reshape(-1) # (M,)\n", + " out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product\n", + "\n", + " emb_sin = np.sin(out) # (M, D/2)\n", + " emb_cos = np.cos(out) # (M, D/2)\n", + "\n", + " emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)\n", + " return emb\n", + "\n", + "\n", + "# --------------------------------------------------------\n", + "# Interpolate position embeddings for high-resolution\n", + "# References:\n", + "# DeiT: https://github.com/facebookresearch/deit\n", + "# --------------------------------------------------------\n", + "def interpolate_pos_embed(model, checkpoint_model):\n", + " if 'pos_embed' in checkpoint_model:\n", + " pos_embed_checkpoint = checkpoint_model['pos_embed']\n", + " embedding_size = pos_embed_checkpoint.shape[-1]\n", + " num_patches = model.patch_embed.num_patches\n", + " num_extra_tokens = model.pos_embed.shape[-2] - num_patches\n", + " # height (== width) for the checkpoint position embedding\n", + " orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)\n", + " # height (== width) for the new position embedding\n", + " new_size = int(num_patches ** 0.5)\n", + " # class_token and dist_token are kept unchanged\n", + " if orig_size != new_size:\n", + " print(\"Position interpolate from %dx%d to %dx%d\" % (orig_size, orig_size, new_size, new_size))\n", + " extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\n", + " # only the position tokens are interpolated\n", + " pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]\n", + " pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)\n", + " pos_tokens = torch.nn.functional.interpolate(\n", + " pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)\n", + " pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\n", + " new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)\n", + " checkpoint_model['pos_embed'] = new_pos_embed\n", + "\n", + "\n", + "\n", + "from collections import OrderedDict\n", + "from typing import Dict, List, Optional, Union\n", + "\n", + "import torch\n", + "from torch import nn\n", + "from torch.nn import functional as F\n", + "\n", + "# from .utils import freeze_batch_norm_2d, feature_take_indices\n", + "\n", + "\n", + "class Bottleneck(nn.Module):\n", + " expansion = 4\n", + "\n", + " def __init__(self, inplanes, planes, stride=1):\n", + " super().__init__()\n", + "\n", + " # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1\n", + " self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)\n", + " self.bn1 = nn.BatchNorm2d(planes)\n", + " self.act1 = nn.ReLU(inplace=True)\n", + "\n", + " self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)\n", + " self.bn2 = nn.BatchNorm2d(planes)\n", + " self.act2 = nn.ReLU(inplace=True)\n", + "\n", + " self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()\n", + "\n", + " self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)\n", + " self.bn3 = nn.BatchNorm2d(planes * self.expansion)\n", + " self.act3 = nn.ReLU(inplace=True)\n", + "\n", + " self.downsample = None\n", + " self.stride = stride\n", + "\n", + " if stride > 1 or inplanes != planes * Bottleneck.expansion:\n", + " # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1\n", + " self.downsample = nn.Sequential(OrderedDict([\n", + " (\"-1\", nn.AvgPool2d(stride)),\n", + " (\"0\", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),\n", + " (\"1\", nn.BatchNorm2d(planes * self.expansion))\n", + " ]))\n", + "\n", + " def forward(self, x: torch.Tensor):\n", + " identity = x\n", + "\n", + " out = self.act1(self.bn1(self.conv1(x)))\n", + " out = self.act2(self.bn2(self.conv2(out)))\n", + " out = self.avgpool(out)\n", + " out = self.bn3(self.conv3(out))\n", + "\n", + " if self.downsample is not None:\n", + " identity = self.downsample(x)\n", + "\n", + " out += identity\n", + " out = self.act3(out)\n", + " return out\n", + "\n", + "\n", + "class AttentionPool2d(nn.Module):\n", + " def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):\n", + " super().__init__()\n", + " self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)\n", + " self.k_proj = nn.Linear(embed_dim, embed_dim)\n", + " self.q_proj = nn.Linear(embed_dim, embed_dim)\n", + " self.v_proj = nn.Linear(embed_dim, embed_dim)\n", + " self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)\n", + " self.num_heads = num_heads\n", + "\n", + " def forward(self, x):\n", + " x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC\n", + " x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC\n", + " x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC\n", + " x, _ = F.multi_head_attention_forward(\n", + " query=x, key=x, value=x,\n", + " embed_dim_to_check=x.shape[-1],\n", + " num_heads=self.num_heads,\n", + " q_proj_weight=self.q_proj.weight,\n", + " k_proj_weight=self.k_proj.weight,\n", + " v_proj_weight=self.v_proj.weight,\n", + " in_proj_weight=None,\n", + " in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),\n", + " bias_k=None,\n", + " bias_v=None,\n", + " add_zero_attn=False,\n", + " dropout_p=0.,\n", + " out_proj_weight=self.c_proj.weight,\n", + " out_proj_bias=self.c_proj.bias,\n", + " use_separate_proj_weight=True,\n", + " training=self.training,\n", + " need_weights=False\n", + " )\n", + "\n", + " return x[0]\n", + "\n", + "\n", + "class ModifiedResNet(nn.Module):\n", + " \"\"\"\n", + " A ResNet class that is similar to torchvision's but contains the following changes:\n", + " - There are now 3 \"stem\" convolutions as opposed to 1, with an average pool instead of a max pool.\n", + " - Performs antialiasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1\n", + " - The final 
pooling layer is a QKV attention instead of an average pool\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " layers: List[int],\n", + " output_dim: int,\n", + " heads: int,\n", + " image_size: int = 224,\n", + " width: int = 64,\n", + " ):\n", + " super().__init__()\n", + " self.output_dim = output_dim\n", + " self.image_size = image_size\n", + "\n", + " # the 3-layer stem\n", + " self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)\n", + " self.bn1 = nn.BatchNorm2d(width // 2)\n", + " self.act1 = nn.ReLU(inplace=True)\n", + " self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)\n", + " self.bn2 = nn.BatchNorm2d(width // 2)\n", + " self.act2 = nn.ReLU(inplace=True)\n", + " self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)\n", + " self.bn3 = nn.BatchNorm2d(width)\n", + " self.act3 = nn.ReLU(inplace=True)\n", + " self.avgpool = nn.AvgPool2d(2)\n", + "\n", + " # residual layers\n", + " self._inplanes = width # this is a *mutable* variable used during construction\n", + " self.layer1 = self._make_layer(width, layers[0])\n", + " self.layer2 = self._make_layer(width * 2, layers[1], stride=2)\n", + " self.layer3 = self._make_layer(width * 4, layers[2], stride=2)\n", + " self.layer4 = self._make_layer(width * 8, layers[3], stride=2)\n", + "\n", + " embed_dim = width * 32 # the ResNet feature dimension\n", + " self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)\n", + "\n", + " self.init_parameters()\n", + "\n", + " def _make_layer(self, planes, blocks, stride=1):\n", + " layers = [Bottleneck(self._inplanes, planes, stride)]\n", + "\n", + " self._inplanes = planes * Bottleneck.expansion\n", + " for _ in range(1, blocks):\n", + " layers.append(Bottleneck(self._inplanes, planes))\n", + "\n", + " return nn.Sequential(*layers)\n", + "\n", + " def init_parameters(self):\n", + " if self.attnpool is not None:\n", + " std = self.attnpool.c_proj.in_features ** -0.5\n", + " nn.init.normal_(self.attnpool.q_proj.weight, std=std)\n", + " nn.init.normal_(self.attnpool.k_proj.weight, std=std)\n", + " nn.init.normal_(self.attnpool.v_proj.weight, std=std)\n", + " nn.init.normal_(self.attnpool.c_proj.weight, std=std)\n", + "\n", + " for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:\n", + " for name, param in resnet_block.named_parameters():\n", + " if name.endswith(\"bn3.weight\"):\n", + " nn.init.zeros_(param)\n", + "\n", + " def lock(self, unlocked_groups=0, freeze_bn_stats=False):\n", + " assert unlocked_groups == 0, 'partial locking not currently supported for this model'\n", + " for param in self.parameters():\n", + " param.requires_grad = False\n", + " if freeze_bn_stats:\n", + " freeze_batch_norm_2d(self)\n", + "\n", + " @torch.jit.ignore\n", + " def set_grad_checkpointing(self, enable=True):\n", + " # FIXME support for non-transformer\n", + " pass\n", + "\n", + " def stem(self, x):\n", + " x = self.act1(self.bn1(self.conv1(x)))\n", + " x = self.act2(self.bn2(self.conv2(x)))\n", + " x = self.act3(self.bn3(self.conv3(x)))\n", + " x = self.avgpool(x)\n", + " return x\n", + "\n", + " def forward_intermediates(\n", + " self,\n", + " x: torch.Tensor,\n", + " indices: Optional[Union[int, List[int]]] = None,\n", + " stop_early: bool = False,\n", + " normalize_intermediates: bool = False,\n", + " intermediates_only: bool = False,\n", + " output_fmt: str = 'NCHW',\n", + " output_extra_tokens: bool = False,\n", + " ) -> Dict[str, Union[torch.Tensor, 
List[torch.Tensor]]]:\n", + " \"\"\" Forward features that returns intermediates.\n", + "\n", + " Args:\n", + " x: Input image tensor\n", + " indices: Take last n blocks if int, all if None, select matching indices if sequence\n", + " stop_early: Stop iterating over blocks when last desired intermediate hit\n", + " normalize_intermediates: Apply final norm layer to all intermediates\n", + " intermediates_only: Only return intermediate features\n", + " output_fmt: Shape of intermediate feature outputs\n", + " output_extra_tokens: Return both extra class, eot tokens\n", + " Returns:\n", + "\n", + " \"\"\"\n", + " assert output_fmt in ('NCHW',), 'Output format must be == NCHW.'\n", + " # NOTE normalize_intermediates and return_extra_tokens don't apply\n", + " take_indices, max_index = feature_take_indices(5, indices)\n", + "\n", + " output = {}\n", + " intermediates = []\n", + " blocks = [self.stem, self.layer1, self.layer2, self.layer3, self.layer4]\n", + " if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript\n", + " blocks = blocks[:max_index + 1]\n", + " for i, blk in enumerate(blocks):\n", + " x = blk(x)\n", + " if i in take_indices:\n", + " intermediates.append(x)\n", + "\n", + " output['image_intermediates'] = intermediates\n", + "\n", + " if intermediates_only:\n", + " return output\n", + "\n", + " x = self.attnpool(x)\n", + " output['image_features'] = x\n", + "\n", + " return output\n", + "\n", + " def forward(self, x):\n", + " x = self.stem(x)\n", + " x = self.layer1(x)\n", + " x = self.layer2(x)\n", + " x = self.layer3(x)\n", + " x = self.layer4(x)\n", + " x = self.attnpool(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "\"\"\" huggingface model adapter\n", + "\n", + "Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.\n", + "\"\"\"\n", + "import re\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch import TensorType\n", + "\n", + "try:\n", + " import transformers\n", + " from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig\n", + " from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \\\n", + " BaseModelOutputWithPoolingAndCrossAttentions\n", + "except ImportError as e:\n", + " transformers = None\n", + "\n", + "\n", + " class BaseModelOutput:\n", + " pass\n", + "\n", + "\n", + " class PretrainedConfig:\n", + " pass\n", + "\n", + "# from .hf_configs import arch_dict\n", + "# HF architecture dict:\n", + "arch_dict = {\n", + " # https://huggingface.co/docs/transformers/model_doc/roberta#roberta\n", + " \"roberta\": {\n", + " \"config_names\": {\n", + " \"context_length\": \"max_position_embeddings\",\n", + " \"vocab_size\": \"vocab_size\",\n", + " \"width\": \"hidden_size\",\n", + " \"heads\": \"num_attention_heads\",\n", + " \"layers\": \"num_hidden_layers\",\n", + " \"layer_attr\": \"layer\",\n", + " \"token_embeddings_attr\": \"embeddings\"\n", + " },\n", + " \"pooler\": \"mean_pooler\",\n", + " },\n", + " # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig\n", + " \"xlm-roberta\": {\n", + " \"config_names\": {\n", + " \"context_length\": \"max_position_embeddings\",\n", + " \"vocab_size\": \"vocab_size\",\n", + " \"width\": \"hidden_size\",\n", + " \"heads\": \"num_attention_heads\",\n", + " \"layers\": \"num_hidden_layers\",\n", + " \"layer_attr\": \"layer\",\n", + " \"token_embeddings_attr\": \"embeddings\"\n", + " },\n", + " \"pooler\": 
\"mean_pooler\",\n", + " },\n", + " # https://huggingface.co/docs/transformers/model_doc/mt5#mt5\n", + " \"mt5\": {\n", + " \"config_names\": {\n", + " # unlimited seqlen\n", + " # https://github.com/google-research/text-to-text-transfer-transformer/issues/273\n", + " # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374\n", + " \"context_length\": \"\",\n", + " \"vocab_size\": \"vocab_size\",\n", + " \"width\": \"d_model\",\n", + " \"heads\": \"num_heads\",\n", + " \"layers\": \"num_layers\",\n", + " \"layer_attr\": \"block\",\n", + " \"token_embeddings_attr\": \"embed_tokens\"\n", + " },\n", + " \"pooler\": \"mean_pooler\",\n", + " },\n", + " # https://huggingface.co/docs/transformers/model_doc/bert\n", + " \"bert\": {\n", + " \"config_names\": {\n", + " \"context_length\": \"max_position_embeddings\",\n", + " \"vocab_size\": \"vocab_size\",\n", + " \"width\": \"hidden_size\",\n", + " \"heads\": \"num_attention_heads\",\n", + " \"layers\": \"num_hidden_layers\",\n", + " },\n", + " \"pooler\": \"cls_pooler\",\n", + " },\n", + " # https://huggingface.co/docs/transformers/model_doc/m2m_100\n", + " \"m2m_100\": {\n", + " \"config_names\": {\n", + " \"context_length\": \"max_position_embeddings\",\n", + " \"vocab_size\": \"vocab_size\",\n", + " \"width\": \"d_model\",\n", + " \"heads\": \"encoder_attention_heads\",\n", + " \"layers\": \"encoder_layers\",\n", + " },\n", + " \"pooler\": \"cls_pooler\",\n", + " },\n", + "}\n", + "\n", + "\n", + "\n", + "# utils\n", + "def _camel2snake(s):\n", + " return re.sub(r'(? Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:\n", + " \"\"\" Forward features that returns intermediates.\n", + "\n", + " Args:\n", + " x: Input image tensor\n", + " indices: Take last n blocks if int, all if None, select matching indices if sequence\n", + " stop_early: Stop iterating over blocks when last desired intermediate hit\n", + " normalize_intermediates: Apply norm layer to all intermediates\n", + " intermediates_only: Only return intermediate features\n", + " output_fmt: Shape of intermediate feature outputs\n", + " output_extra_tokens: Return both prefix and spatial intermediate tokens\n", + " Returns:\n", + " \"\"\"\n", + " extra_args = {}\n", + " if output_extra_tokens:\n", + " extra_args['return_prefix_tokens'] = True\n", + " trunk_output = self.trunk.forward_intermediates(\n", + " x,\n", + " indices=indices,\n", + " intermediates_only=intermediates_only,\n", + " norm=normalize_intermediates,\n", + " stop_early=stop_early,\n", + " output_fmt=output_fmt,\n", + " **extra_args,\n", + " )\n", + "\n", + " return_dict = {}\n", + " intermediates = trunk_output if intermediates_only else trunk_output[1]\n", + " if output_extra_tokens and intermediates and isinstance(intermediates[0], tuple):\n", + " intermediates_prefix = [xi[1] for xi in intermediates]\n", + " intermediates = [xi[0] for xi in intermediates]\n", + " return_dict['image_intermediates_prefix'] = intermediates_prefix\n", + "\n", + " return_dict['image_intermediates'] = intermediates\n", + " if intermediates_only:\n", + " return return_dict\n", + "\n", + " image_features = self.trunk.forward_head(trunk_output[0]) # run through timm pooling / projection\n", + " image_features = self.head(image_features) # run through adapter pooling / projection\n", + " return_dict['image_features'] = image_features\n", + " return return_dict\n", + "\n", + " def forward(self, x):\n", + " x = self.trunk(x)\n", + " x = self.head(x)\n", + " return x\n", + "\n", + "\n", + "class 
LayerNormFp32(nn.LayerNorm):\n", + " \"\"\"Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).\"\"\"\n", + "\n", + " def forward(self, x: torch.Tensor):\n", + " orig_type = x.dtype\n", + " x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps)\n", + " return x.to(orig_type)\n", + "\n", + "\n", + "class LayerNorm(nn.LayerNorm):\n", + " \"\"\"Subclass torch's LayerNorm (with cast back to input dtype).\"\"\"\n", + "\n", + " def forward(self, x: torch.Tensor):\n", + " orig_type = x.dtype\n", + " x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)\n", + " return x.to(orig_type)\n", + "\n", + "\n", + "class QuickGELU(nn.Module):\n", + " # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory\n", + " def forward(self, x: torch.Tensor):\n", + " return x * torch.sigmoid(1.702 * x)\n", + "\n", + "\n", + "class LayerScale(nn.Module):\n", + " def __init__(self, dim, init_values=1e-5, inplace=False):\n", + " super().__init__()\n", + " self.inplace = inplace\n", + " self.gamma = nn.Parameter(init_values * torch.ones(dim))\n", + "\n", + " def forward(self, x):\n", + " return x.mul_(self.gamma) if self.inplace else x * self.gamma\n", + "\n", + "\n", + "class PatchDropout(nn.Module):\n", + " \"\"\"\n", + " https://arxiv.org/abs/2212.00794\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " prob: float = 0.5,\n", + " exclude_first_token: bool = True\n", + " ):\n", + " super().__init__()\n", + " assert 0 <= prob < 1.\n", + " self.prob = prob\n", + " self.exclude_first_token = exclude_first_token # exclude CLS token\n", + "\n", + " def forward(self, x):\n", + " if not self.training or self.prob == 0.:\n", + " return x\n", + "\n", + " if self.exclude_first_token:\n", + " cls_tokens, x = x[:, :1], x[:, 1:]\n", + " else:\n", + " cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])\n", + "\n", + " batch = x.size()[0]\n", + " num_tokens = x.size()[1]\n", + "\n", + " batch_indices = torch.arange(batch)\n", + " batch_indices = batch_indices[..., None]\n", + "\n", + " keep_prob = 1 - self.prob\n", + " num_patches_keep = max(1, int(num_tokens * keep_prob))\n", + "\n", + " rand = torch.randn(batch, num_tokens)\n", + " patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices\n", + "\n", + " x = x[batch_indices, patch_indices_keep]\n", + "\n", + " if self.exclude_first_token:\n", + " x = torch.cat((cls_tokens, x), dim=1)\n", + "\n", + " return x\n", + "\n", + "\n", + "class Attention(nn.Module):\n", + " def __init__(\n", + " self,\n", + " dim: int,\n", + " num_heads: int = 8,\n", + " qkv_bias: bool = True,\n", + " scaled_cosine: bool = False,\n", + " scale_heads: bool = False,\n", + " logit_scale_max: float = math.log(1. 
/ 0.01),\n", + " batch_first: bool = True,\n", + " attn_drop: float = 0.,\n", + " proj_drop: float = 0.\n", + " ):\n", + " super().__init__()\n", + " self.scaled_cosine = scaled_cosine\n", + " self.scale_heads = scale_heads\n", + " assert dim % num_heads == 0, 'dim should be divisible by num_heads'\n", + " self.num_heads = num_heads\n", + " self.head_dim = dim // num_heads\n", + " self.scale = self.head_dim ** -0.5\n", + " self.logit_scale_max = logit_scale_max\n", + " self.batch_first = batch_first\n", + " self.use_fsdpa = hasattr(nn.functional, 'scaled_dot_product_attention')\n", + "\n", + " # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original\n", + " self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)\n", + " if qkv_bias:\n", + " self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))\n", + " else:\n", + " self.in_proj_bias = None\n", + "\n", + " if self.scaled_cosine:\n", + " self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))\n", + " else:\n", + " self.logit_scale = None\n", + " self.attn_drop = nn.Dropout(attn_drop)\n", + " if self.scale_heads:\n", + " self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))\n", + " else:\n", + " self.head_scale = None\n", + " self.out_proj = nn.Linear(dim, dim)\n", + " self.out_drop = nn.Dropout(proj_drop)\n", + "\n", + " def forward(self, x, attn_mask: Optional[torch.Tensor] = None):\n", + " if self.batch_first:\n", + " x = x.transpose(0, 1)\n", + "\n", + " L, N, C = x.shape\n", + " q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)\n", + " q = q.reshape(L, N * self.num_heads, -1).transpose(0, 1)\n", + " k = k.reshape(L, N * self.num_heads, -1).transpose(0, 1)\n", + " v = v.reshape(L, N * self.num_heads, -1).transpose(0, 1)\n", + "\n", + " if attn_mask is not None and attn_mask.dtype == torch.bool:\n", + " new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)\n", + " new_attn_mask.masked_fill_(attn_mask, float(\"-inf\"))\n", + " attn_mask = new_attn_mask\n", + "\n", + " if self.logit_scale is not None:\n", + " attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))\n", + " logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()\n", + " attn = attn.view(N, self.num_heads, L, L) * logit_scale\n", + " attn = attn.view(-1, L, L)\n", + " if attn_mask is not None:\n", + " attn = attn + attn_mask\n", + " attn = attn.softmax(dim=-1)\n", + " attn = self.attn_drop(attn)\n", + " x = torch.bmm(attn, v)\n", + " else:\n", + " if self.use_fsdpa:\n", + " x = F.scaled_dot_product_attention(\n", + " q, k, v,\n", + " attn_mask=attn_mask,\n", + " dropout_p=self.attn_drop.p if self.training else 0.,\n", + " )\n", + " else:\n", + " q = q * self.scale\n", + " attn = torch.bmm(q, k.transpose(-1, -2))\n", + " if attn_mask is not None:\n", + " attn += attn_mask\n", + " attn = attn.softmax(dim=-1)\n", + " attn = self.attn_drop(attn)\n", + " x = torch.bmm(attn, v)\n", + "\n", + " if self.head_scale is not None:\n", + " x = x.view(N, self.num_heads, L, C) * self.head_scale\n", + " x = x.view(-1, L, C)\n", + "\n", + " x = x.transpose(0, 1).reshape(L, N, C)\n", + "\n", + " if self.batch_first:\n", + " x = x.transpose(0, 1)\n", + "\n", + " x = self.out_proj(x)\n", + " x = self.out_drop(x)\n", + " return x\n", + "\n", + "\n", + "class AttentionalPooler(nn.Module):\n", + " def __init__(\n", + " self,\n", + " d_model: int,\n", + " context_dim: int,\n", + " n_head: int = 8,\n", + " n_queries: int = 256,\n", + " 
norm_layer: Callable = LayerNorm,\n", + " ):\n", + " super().__init__()\n", + " self.query = nn.Parameter(torch.randn(n_queries, d_model))\n", + " self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim, batch_first=True)\n", + " self.ln_q = norm_layer(d_model)\n", + " self.ln_k = norm_layer(context_dim)\n", + "\n", + " def forward(self, x: torch.Tensor):\n", + " N = x.shape[0]\n", + " x = self.ln_k(x)\n", + " q = self.ln_q(self.query)\n", + " out = self.attn(q.unsqueeze(0).expand(N, -1, -1), x, x, need_weights=False)[0]\n", + " return out\n", + "\n", + "\n", + "class ResidualAttentionBlock(nn.Module):\n", + " def __init__(\n", + " self,\n", + " d_model: int,\n", + " n_head: int,\n", + " mlp_ratio: float = 4.0,\n", + " ls_init_value: float = None,\n", + " act_layer: Callable = nn.GELU,\n", + " norm_layer: Callable = LayerNorm,\n", + " is_cross_attention: bool = False,\n", + " batch_first: bool = True,\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.ln_1 = norm_layer(d_model)\n", + " self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=batch_first)\n", + " self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\n", + " if is_cross_attention:\n", + " self.ln_1_kv = norm_layer(d_model)\n", + "\n", + " self.ln_2 = norm_layer(d_model)\n", + " mlp_width = int(d_model * mlp_ratio)\n", + " self.mlp = nn.Sequential(OrderedDict([\n", + " (\"c_fc\", nn.Linear(d_model, mlp_width)),\n", + " (\"gelu\", act_layer()),\n", + " (\"c_proj\", nn.Linear(mlp_width, d_model))\n", + " ]))\n", + " self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\n", + "\n", + " def attention(\n", + " self,\n", + " q_x: torch.Tensor,\n", + " k_x: Optional[torch.Tensor] = None,\n", + " v_x: Optional[torch.Tensor] = None,\n", + " attn_mask: Optional[torch.Tensor] = None,\n", + " ):\n", + " k_x = k_x if k_x is not None else q_x\n", + " v_x = v_x if v_x is not None else q_x\n", + "\n", + " attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None\n", + " return self.attn(\n", + " q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask\n", + " )[0]\n", + "\n", + " def forward(\n", + " self,\n", + " q_x: torch.Tensor,\n", + " k_x: Optional[torch.Tensor] = None,\n", + " v_x: Optional[torch.Tensor] = None,\n", + " attn_mask: Optional[torch.Tensor] = None,\n", + " ):\n", + " k_x = self.ln_1_kv(k_x) if hasattr(self, \"ln_1_kv\") and k_x is not None else None\n", + " v_x = self.ln_1_kv(v_x) if hasattr(self, \"ln_1_kv\") and v_x is not None else None\n", + " x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))\n", + " x = x + self.ls_2(self.mlp(self.ln_2(x)))\n", + " return x\n", + "\n", + "\n", + "class CustomResidualAttentionBlock(nn.Module):\n", + " def __init__(\n", + " self,\n", + " d_model: int,\n", + " n_head: int,\n", + " mlp_ratio: float = 4.0,\n", + " ls_init_value: float = None,\n", + " act_layer: Callable = nn.GELU,\n", + " norm_layer: Callable = LayerNorm,\n", + " scale_cosine_attn: bool = False,\n", + " scale_heads: bool = False,\n", + " scale_attn: bool = False,\n", + " scale_fc: bool = False,\n", + " batch_first: bool = True,\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.ln_1 = norm_layer(d_model)\n", + " self.attn = Attention(\n", + " d_model,\n", + " n_head,\n", + " scaled_cosine=scale_cosine_attn,\n", + " scale_heads=scale_heads,\n", + " batch_first=batch_first,\n", + " )\n", + " self.ln_attn = norm_layer(d_model) if 
scale_attn else nn.Identity()\n", + " self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\n", + "\n", + " self.ln_2 = norm_layer(d_model)\n", + " mlp_width = int(d_model * mlp_ratio)\n", + " self.mlp = nn.Sequential(OrderedDict([\n", + " (\"c_fc\", nn.Linear(d_model, mlp_width)),\n", + " (\"gelu\", act_layer()),\n", + " ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),\n", + " (\"c_proj\", nn.Linear(mlp_width, d_model))\n", + " ]))\n", + " self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()\n", + "\n", + " def get_reference_weight(self):\n", + " return self.mlp.c_fc.weight\n", + "\n", + " def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):\n", + " x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))\n", + " x = x + self.ls_2(self.mlp(self.ln_2(x)))\n", + " return x\n", + "\n", + "\n", + "class CustomTransformer(nn.Module):\n", + " \"\"\" A custom transformer that can use different block types. \"\"\"\n", + " def __init__(\n", + " self,\n", + " width: int,\n", + " layers: int,\n", + " heads: int,\n", + " mlp_ratio: float = 4.0,\n", + " ls_init_value: float = None,\n", + " act_layer: Callable = nn.GELU,\n", + " norm_layer: Callable = LayerNorm,\n", + " batch_first: bool = True,\n", + " block_types: Union[str, List[str]] = 'CustomResidualAttentionBlock',\n", + " ):\n", + " super().__init__()\n", + " self.width = width\n", + " self.layers = layers\n", + " self.batch_first = batch_first # run transformer stack in batch first (N, L, D)\n", + " self.grad_checkpointing = False\n", + "\n", + " if isinstance(block_types, str):\n", + " block_types = [block_types] * layers\n", + " assert len(block_types) == layers\n", + "\n", + " def _create_block(bt: str):\n", + " if bt == 'CustomResidualAttentionBlock':\n", + " return CustomResidualAttentionBlock(\n", + " width,\n", + " heads,\n", + " mlp_ratio=mlp_ratio,\n", + " ls_init_value=ls_init_value,\n", + " act_layer=act_layer,\n", + " norm_layer=norm_layer,\n", + " batch_first=batch_first,\n", + " )\n", + " else:\n", + " assert False\n", + "\n", + " self.resblocks = nn.ModuleList([\n", + " _create_block(bt)\n", + " for bt in block_types\n", + " ])\n", + "\n", + " def get_cast_dtype(self) -> torch.dtype:\n", + " weight = self.resblocks[0].get_reference_weight()\n", + " if hasattr(weight, 'int8_original_dtype'):\n", + " return weight.int8_original_dtype\n", + " return weight.dtype\n", + "\n", + " def forward_intermediates(\n", + " self,\n", + " x: torch.Tensor,\n", + " attn_mask: Optional[torch.Tensor] = None,\n", + " indices: Optional[Union[int, List[int]]] = None,\n", + " stop_early: bool = False,\n", + " ):\n", + " take_indices, max_index = feature_take_indices(len(self.resblocks), indices)\n", + "\n", + " if not self.batch_first:\n", + " x = x.transpose(0, 1).contiguous() # NLD -> LND\n", + "\n", + " intermediates = []\n", + " if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript\n", + " blocks = self.resblocks\n", + " else:\n", + " blocks = self.resblocks[:max_index + 1]\n", + " for i, blk in enumerate(blocks):\n", + " if self.grad_checkpointing and not torch.jit.is_scripting():\n", + " x = checkpoint(blk, x, None, None, attn_mask, use_reentrant=False)\n", + " else:\n", + " x = blk(x, attn_mask=attn_mask)\n", + "\n", + " if i in take_indices:\n", + " intermediates.append(x.transpose(0, 1) if not self.batch_first else x)\n", + "\n", + " if not self.batch_first:\n", + " x = 
x.transpose(0, 1) # LND -> NLD\n", + "\n", + " return x, intermediates\n", + "\n", + " def prune_intermediate_layers(self, indices: Union[int, List[int]] = 1):\n", + " \"\"\" Prune layers not required for specified intermediates.\n", + " \"\"\"\n", + " take_indices, max_index = feature_take_indices(len(self.resblocks), indices)\n", + " self.resblocks = self.resblocks[:max_index + 1] # truncate blocks\n", + " return take_indices\n", + "\n", + " def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):\n", + " if not self.batch_first:\n", + " x = x.transpose(0, 1) # NLD -> LND\n", + "\n", + " for r in self.resblocks:\n", + " if self.grad_checkpointing and not torch.jit.is_scripting():\n", + " # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372\n", + " x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)\n", + " else:\n", + " x = r(x, attn_mask=attn_mask)\n", + "\n", + " if not self.batch_first:\n", + " x = x.transpose(0, 1) # NLD -> LND\n", + " return x\n", + "\n", + "\n", + "class Transformer(nn.Module):\n", + " def __init__(\n", + " self,\n", + " width: int,\n", + " layers: int,\n", + " heads: int,\n", + " mlp_ratio: float = 4.0,\n", + " ls_init_value: float = None,\n", + " act_layer: Callable = nn.GELU,\n", + " norm_layer: Callable = LayerNorm,\n", + " batch_first: bool = True,\n", + " ):\n", + " super().__init__()\n", + " self.width = width\n", + " self.layers = layers\n", + " self.batch_first = batch_first\n", + " self.grad_checkpointing = False\n", + "\n", + " self.resblocks = nn.ModuleList([\n", + " ResidualAttentionBlock(\n", + " width,\n", + " heads,\n", + " mlp_ratio,\n", + " ls_init_value=ls_init_value,\n", + " act_layer=act_layer,\n", + " norm_layer=norm_layer,\n", + " batch_first=batch_first,\n", + " )\n", + " for _ in range(layers)\n", + " ])\n", + "\n", + " def get_cast_dtype(self) -> torch.dtype:\n", + " if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'):\n", + " return self.resblocks[0].mlp.c_fc.int8_original_dtype\n", + " return self.resblocks[0].mlp.c_fc.weight.dtype\n", + "\n", + " def forward_intermediates(\n", + " self,\n", + " x: torch.Tensor,\n", + " attn_mask: Optional[torch.Tensor] = None,\n", + " indices: Optional[Union[int, List[int]]] = None,\n", + " stop_early: bool = False,\n", + " ):\n", + " take_indices, max_index = feature_take_indices(len(self.resblocks), indices)\n", + "\n", + " if not self.batch_first:\n", + " x = x.transpose(0, 1).contiguous() # NLD -> LND\n", + "\n", + " intermediates = []\n", + " if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript\n", + " blocks = self.resblocks\n", + " else:\n", + " blocks = self.resblocks[:max_index + 1]\n", + " for i, blk in enumerate(blocks):\n", + " if self.grad_checkpointing and not torch.jit.is_scripting():\n", + " x = checkpoint(blk, x, None, None, attn_mask, use_reentrant=False)\n", + " else:\n", + " x = blk(x, attn_mask=attn_mask)\n", + "\n", + " if i in take_indices:\n", + " intermediates.append(x.transpose(0, 1) if not self.batch_first else x)\n", + "\n", + " if not self.batch_first:\n", + " x = x.transpose(0, 1) # LND -> NLD\n", + "\n", + " return x, intermediates\n", + "\n", + " def prune_intermediate_layers(self, indices: Union[int, List[int]] = 1):\n", + " \"\"\" Prune layers not required for specified intermediates.\n", + " \"\"\"\n", + " take_indices, max_index = feature_take_indices(len(self.resblocks), indices)\n", + " self.resblocks = self.resblocks[:max_index + 1] # truncate 
blocks\n", + " return take_indices\n", + "\n", + " def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):\n", + " if not self.batch_first:\n", + " x = x.transpose(0, 1).contiguous() # NLD -> LND\n", + "\n", + " for r in self.resblocks:\n", + " if self.grad_checkpointing and not torch.jit.is_scripting():\n", + " # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372\n", + " x = checkpoint(r, x, None, None, attn_mask, use_reentrant=False)\n", + " else:\n", + " x = r(x, attn_mask=attn_mask)\n", + "\n", + " if not self.batch_first:\n", + " x = x.transpose(0, 1) # LND -> NLD\n", + " return x\n", + "\n", + "\n", + "def _expand_token(token, batch_size: int):\n", + " return token.view(1, 1, -1).expand(batch_size, -1, -1)\n", + "\n", + "\n", + "class VisionTransformer(nn.Module):\n", + " output_tokens: torch.jit.Final[bool]\n", + "\n", + " def __init__(\n", + " self,\n", + " image_size: int,\n", + " patch_size: int,\n", + " width: int,\n", + " layers: int,\n", + " heads: int,\n", + " mlp_ratio: float,\n", + " ls_init_value: float = None,\n", + " attentional_pool: bool = False,\n", + " attn_pooler_queries: int = 256,\n", + " attn_pooler_heads: int = 8,\n", + " output_dim: int = 512,\n", + " patch_dropout: float = 0.,\n", + " no_ln_pre: bool = False,\n", + " pos_embed_type: str = 'learnable',\n", + " pool_type: str = 'tok',\n", + " final_ln_after_pool: bool = False,\n", + " act_layer: Callable = nn.GELU,\n", + " norm_layer: Callable = LayerNorm,\n", + " output_tokens: bool = False,\n", + " ):\n", + " super().__init__()\n", + " assert pool_type in ('tok', 'avg', 'none')\n", + " self.output_tokens = output_tokens\n", + " image_height, image_width = self.image_size = to_2tuple(image_size)\n", + " patch_height, patch_width = self.patch_size = to_2tuple(patch_size)\n", + " self.grid_size = (image_height // patch_height, image_width // patch_width)\n", + " self.final_ln_after_pool = final_ln_after_pool # currently ignored w/ attn pool enabled\n", + " self.output_dim = output_dim\n", + "\n", + " self.conv1 = nn.Conv2d(\n", + " in_channels=3,\n", + " out_channels=width,\n", + " kernel_size=patch_size,\n", + " stride=patch_size,\n", + " bias=False,\n", + " )\n", + "\n", + " # class embeddings and positional embeddings\n", + " scale = width ** -0.5\n", + " self.class_embedding = nn.Parameter(scale * torch.randn(width))\n", + " if pos_embed_type == 'learnable':\n", + " self.positional_embedding = nn.Parameter(\n", + " scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))\n", + " elif pos_embed_type == 'sin_cos_2d':\n", + " # fixed sin-cos embedding\n", + " assert self.grid_size[0] == self.grid_size[1],\\\n", + " 'currently sin cos 2d pos embedding only supports square input'\n", + " self.positional_embedding = nn.Parameter(\n", + " torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False)\n", + " pos_embed_type = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True)\n", + " self.positional_embedding.data.copy_(torch.from_numpy(pos_embed_type).float())\n", + " else:\n", + " raise ValueError\n", + "\n", + " # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn\n", + " self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. 
else nn.Identity()\n", + "\n", + " self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width)\n", + " self.transformer = Transformer(\n", + " width,\n", + " layers,\n", + " heads,\n", + " mlp_ratio,\n", + " ls_init_value=ls_init_value,\n", + " act_layer=act_layer,\n", + " norm_layer=norm_layer,\n", + " )\n", + "\n", + " if attentional_pool:\n", + " if isinstance(attentional_pool, str):\n", + " self.attn_pool_type = attentional_pool\n", + " self.pool_type = 'none'\n", + " if attentional_pool in ('parallel', 'cascade'):\n", + " self.attn_pool = AttentionalPooler(\n", + " output_dim,\n", + " width,\n", + " n_head=attn_pooler_heads,\n", + " n_queries=attn_pooler_queries,\n", + " )\n", + " self.attn_pool_contrastive = AttentionalPooler(\n", + " output_dim,\n", + " width,\n", + " n_head=attn_pooler_heads,\n", + " n_queries=1,\n", + " )\n", + " else:\n", + " assert False\n", + " else:\n", + " self.attn_pool_type = ''\n", + " self.pool_type = pool_type\n", + " self.attn_pool = AttentionalPooler(\n", + " output_dim,\n", + " width,\n", + " n_head=attn_pooler_heads,\n", + " n_queries=attn_pooler_queries,\n", + " )\n", + " self.attn_pool_contrastive = None\n", + " pool_dim = output_dim\n", + " else:\n", + " self.attn_pool = None\n", + " pool_dim = width\n", + " self.pool_type = pool_type\n", + "\n", + " self.ln_post = norm_layer(pool_dim)\n", + " self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim))\n", + "\n", + " self.init_parameters()\n", + "\n", + " def lock(self, unlocked_groups: int = 0, freeze_bn_stats: bool = False):\n", + " for param in self.parameters():\n", + " param.requires_grad = False\n", + "\n", + " if unlocked_groups != 0:\n", + " groups = [\n", + " [\n", + " self.conv1,\n", + " self.class_embedding,\n", + " self.positional_embedding,\n", + " self.ln_pre,\n", + " ],\n", + " *self.transformer.resblocks[:-1],\n", + " [\n", + " self.transformer.resblocks[-1],\n", + " self.ln_post,\n", + " ],\n", + " self.proj,\n", + " ]\n", + "\n", + " def _unlock(x):\n", + " if isinstance(x, Sequence):\n", + " for g in x:\n", + " _unlock(g)\n", + " else:\n", + " if isinstance(x, torch.nn.Parameter):\n", + " x.requires_grad = True\n", + " else:\n", + " for p in x.parameters():\n", + " p.requires_grad = True\n", + "\n", + " _unlock(groups[-unlocked_groups:])\n", + "\n", + " def init_parameters(self):\n", + " # FIXME OpenAI CLIP did not define an init for the VisualTransformer\n", + " # TODO experiment if default PyTorch init, below, or alternate init is best.\n", + "\n", + " # nn.init.normal_(self.class_embedding, std=self.scale)\n", + " # nn.init.normal_(self.positional_embedding, std=self.scale)\n", + " #\n", + " # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)\n", + " # attn_std = self.transformer.width ** -0.5\n", + " # fc_std = (2 * self.transformer.width) ** -0.5\n", + " # for block in self.transformer.resblocks:\n", + " # nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n", + " # nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n", + " # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n", + " # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n", + " #\n", + " # if self.text_projection is not None:\n", + " # nn.init.normal_(self.text_projection, std=self.scale)\n", + " pass\n", + "\n", + " @torch.jit.ignore\n", + " def set_grad_checkpointing(self, enable: bool = True):\n", + " self.transformer.grad_checkpointing = enable\n", + "\n", + " @torch.jit.ignore\n", + " def no_weight_decay(self):\n", + " # for timm 
optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default\n", + " no_wd = {'positional_embedding', 'class_embedding'}\n", + " return no_wd\n", + "\n", + " def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n", + " if self.pool_type == 'avg':\n", + " pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:]\n", + " elif self.pool_type == 'tok':\n", + " pooled, tokens = x[:, 0], x[:, 1:]\n", + " else:\n", + " pooled = tokens = x\n", + "\n", + " return pooled, tokens\n", + "\n", + " def _embeds(self, x:torch.Tensor) -> torch.Tensor:\n", + " x = self.conv1(x) # shape = [*, dim, grid, grid]\n", + " x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]\n", + " x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]\n", + "\n", + " # class embeddings and positional embeddings\n", + " x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)\n", + " # shape = [*, grid ** 2 + 1, width]\n", + " x = x + self.positional_embedding.to(x.dtype)\n", + "\n", + " # patch dropout (if active)\n", + " x = self.patch_dropout(x)\n", + "\n", + " # apply norm before transformer\n", + " x = self.ln_pre(x)\n", + " return x\n", + "\n", + " def _pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n", + " if self.attn_pool is not None:\n", + " if self.attn_pool_contrastive is not None:\n", + " # This is untested, WIP pooling that should match paper\n", + " x = self.ln_post(x) # TBD LN first or separate one after each pool?\n", + " tokens = self.attn_pool(x)\n", + " if self.attn_pool_type == 'parallel':\n", + " pooled = self.attn_pool_contrastive(x)\n", + " else:\n", + " assert self.attn_pool_type == 'cascade'\n", + " pooled = self.attn_pool_contrastive(tokens)\n", + " else:\n", + " # this is the original OpenCLIP CoCa setup, does not match paper\n", + " x = self.attn_pool(x)\n", + " x = self.ln_post(x)\n", + " pooled, tokens = self._global_pool(x)\n", + " elif self.final_ln_after_pool:\n", + " pooled, tokens = self._global_pool(x)\n", + " pooled = self.ln_post(pooled)\n", + " else:\n", + " x = self.ln_post(x)\n", + " pooled, tokens = self._global_pool(x)\n", + "\n", + " return pooled, tokens\n", + "\n", + " def forward_intermediates(\n", + " self,\n", + " x: torch.Tensor,\n", + " indices: Optional[Union[int, List[int]]] = None,\n", + " stop_early: bool = False,\n", + " normalize_intermediates: bool = False,\n", + " intermediates_only: bool = False,\n", + " output_fmt: str = 'NCHW',\n", + " output_extra_tokens: bool = False,\n", + " ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:\n", + " \"\"\" Forward features that returns intermediates.\n", + "\n", + " Args:\n", + " x: Input image tensor\n", + " indices: Take last n blocks if int, all if None, select matching indices if sequence\n", + " stop_early: Stop iterating over blocks when last desired intermediate hit\n", + " intermediates_only: Only return intermediate features\n", + " normalize_intermediates: Apply final norm layer to all intermediates\n", + " output_fmt: Shape of intermediate feature outputs\n", + " output_extra_tokens: Return both extra prefix class tokens\n", + " Returns:\n", + "\n", + " \"\"\"\n", + " assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'\n", + " reshape = output_fmt == 'NCHW'\n", + "\n", + " # forward pass\n", + " B, _, height, width = x.shape\n", + " x = self._embeds(x)\n", + " x, intermediates = self.transformer.forward_intermediates(\n", + " x,\n", + " indices=indices,\n", + " 
stop_early=stop_early,\n", + " )\n", + "\n", + " # process intermediates\n", + " if normalize_intermediates:\n", + " # apply final norm to all intermediates\n", + " intermediates = [self.ln_post(xi) for xi in intermediates]\n", + " num_prefix_tokens = 1 # one class token that's always there (as of now)\n", + " if num_prefix_tokens:\n", + " # split prefix (e.g. class, distill) and spatial feature tokens\n", + " prefix_tokens = [y[:, 0:num_prefix_tokens] for y in intermediates]\n", + " intermediates = [y[:, num_prefix_tokens:] for y in intermediates]\n", + " else:\n", + " prefix_tokens = None\n", + " if reshape:\n", + " # reshape to BCHW output format\n", + " H, W = height // self.patch_size[0], width // self.patch_size[1]\n", + " intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]\n", + "\n", + " output = {'image_intermediates': intermediates}\n", + " if prefix_tokens is not None and output_extra_tokens:\n", + " output['image_intermediates_prefix'] = prefix_tokens\n", + "\n", + " if intermediates_only:\n", + " return output\n", + "\n", + " pooled, _ = self._pool(x)\n", + "\n", + " if self.proj is not None:\n", + " pooled = pooled @ self.proj\n", + "\n", + " output['image_features'] = pooled\n", + "\n", + " return output\n", + "\n", + " def prune_intermediate_layers(\n", + " self,\n", + " indices: Union[int, List[int]] = 1,\n", + " prune_norm: bool = False,\n", + " prune_head: bool = True,\n", + " ):\n", + " \"\"\" Prune layers not required for specified intermediates.\n", + " \"\"\"\n", + " take_indices = self.transformer.prune_intermediate_layers(indices)\n", + " if prune_norm:\n", + " self.ln_post = nn.Identity()\n", + " if prune_head:\n", + " self.proj = None\n", + " return take_indices\n", + "\n", + " def forward(self, x: torch.Tensor):\n", + " x = self._embeds(x)\n", + " x = self.transformer(x)\n", + " pooled, tokens = self._pool(x)\n", + "\n", + " if self.proj is not None:\n", + " pooled = pooled @ self.proj\n", + "\n", + " if self.output_tokens:\n", + " return pooled, tokens\n", + " \n", + " return pooled\n", + "\n", + "\n", + "def text_global_pool(\n", + " x: torch.Tensor,\n", + " text: Optional[torch.Tensor] = None,\n", + " pool_type: str = 'argmax',\n", + ") -> torch.Tensor:\n", + " if pool_type == 'first':\n", + " pooled = x[:, 0]\n", + " elif pool_type == 'last':\n", + " pooled = x[:, -1]\n", + " elif pool_type == 'argmax':\n", + " # take features from the eot embedding (eot_token is the highest number in each sequence)\n", + " assert text is not None\n", + " pooled = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]\n", + " else:\n", + " pooled = x\n", + "\n", + " return pooled\n", + "\n", + "\n", + "class TextTransformer(nn.Module):\n", + " output_tokens: torch.jit.Final[bool]\n", + "\n", + " def __init__(\n", + " self,\n", + " context_length: int = 77,\n", + " vocab_size: int = 49408,\n", + " width: int = 512,\n", + " heads: int = 8,\n", + " layers: int = 12,\n", + " mlp_ratio: float = 4.0,\n", + " ls_init_value: float = None,\n", + " output_dim: Optional[int] = 512,\n", + " embed_cls: bool = False,\n", + " no_causal_mask: bool = False,\n", + " pad_id: int = 0,\n", + " pool_type: str = 'argmax',\n", + " proj_type: str = 'linear',\n", + " proj_bias: bool = False,\n", + " act_layer: Callable = nn.GELU,\n", + " norm_layer: Callable = LayerNorm,\n", + " output_tokens: bool = False,\n", + " ):\n", + " super().__init__()\n", + " assert pool_type in ('first', 'last', 'argmax', 'none')\n", + " self.output_tokens = output_tokens\n", + " 
self.num_pos = self.context_length = context_length\n", + " self.vocab_size = vocab_size\n", + " self.width = width\n", + " self.output_dim = output_dim\n", + " self.heads = heads\n", + " self.pad_id = pad_id\n", + " self.pool_type = pool_type\n", + "\n", + " self.token_embedding = nn.Embedding(vocab_size, width)\n", + " if embed_cls:\n", + " self.cls_emb = nn.Parameter(torch.empty(width))\n", + " self.num_pos += 1\n", + " else:\n", + " self.cls_emb = None\n", + " self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))\n", + " self.transformer = Transformer(\n", + " width=width,\n", + " layers=layers,\n", + " heads=heads,\n", + " mlp_ratio=mlp_ratio,\n", + " ls_init_value=ls_init_value,\n", + " act_layer=act_layer,\n", + " norm_layer=norm_layer,\n", + " )\n", + " self.ln_final = norm_layer(width)\n", + "\n", + " if no_causal_mask:\n", + " self.attn_mask = None\n", + " else:\n", + " self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False)\n", + "\n", + " if proj_type == 'none' or not output_dim:\n", + " self.text_projection = None\n", + " else:\n", + " if proj_bias:\n", + " self.text_projection = nn.Linear(width, output_dim)\n", + " else:\n", + " self.text_projection = nn.Parameter(torch.empty(width, output_dim))\n", + "\n", + " self.init_parameters()\n", + "\n", + " def init_parameters(self):\n", + " nn.init.normal_(self.token_embedding.weight, std=0.02)\n", + " nn.init.normal_(self.positional_embedding, std=0.01)\n", + " if self.cls_emb is not None:\n", + " nn.init.normal_(self.cls_emb, std=0.01)\n", + "\n", + " proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)\n", + " attn_std = self.transformer.width ** -0.5\n", + " fc_std = (2 * self.transformer.width) ** -0.5\n", + " for block in self.transformer.resblocks:\n", + " nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n", + " nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n", + " nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n", + " nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n", + "\n", + " if self.text_projection is not None:\n", + " if isinstance(self.text_projection, nn.Linear):\n", + " nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5)\n", + " if self.text_projection.bias is not None:\n", + " nn.init.zeros_(self.text_projection.bias)\n", + " else:\n", + " nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)\n", + "\n", + " @torch.jit.ignore\n", + " def set_grad_checkpointing(self, enable=True):\n", + " self.transformer.grad_checkpointing = enable\n", + "\n", + " @torch.jit.ignore\n", + " def no_weight_decay(self):\n", + " # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default\n", + " no_wd = {'positional_embedding'}\n", + " if self.cls_emb is not None:\n", + " no_wd.add('cls_emb')\n", + " return no_wd\n", + "\n", + " def build_causal_mask(self):\n", + " # lazily create causal attention mask, with full attention between the tokens\n", + " # pytorch uses additive attention mask; fill with -inf\n", + " mask = torch.empty(self.num_pos, self.num_pos)\n", + " mask.fill_(float(\"-inf\"))\n", + " mask.triu_(1) # zero out the lower diagonal\n", + " return mask\n", + "\n", + " def build_cls_mask(self, text, cast_dtype: torch.dtype):\n", + " cls_mask = (text != self.pad_id).unsqueeze(1)\n", + " cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)\n", + " additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, 
device=cls_mask.device)\n", + " additive_mask.fill_(0)\n", + " additive_mask.masked_fill_(~cls_mask, float(\"-inf\"))\n", + " additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)\n", + " return additive_mask\n", + "\n", + " def _embeds(self, text) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:\n", + " cast_dtype = self.transformer.get_cast_dtype()\n", + " seq_len = text.shape[1]\n", + " x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]\n", + " attn_mask = self.attn_mask\n", + " if self.cls_emb is not None:\n", + " seq_len += 1\n", + " x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1)\n", + " cls_mask = self.build_cls_mask(text, cast_dtype)\n", + " if attn_mask is not None:\n", + " attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]\n", + " x = x + self.positional_embedding[:seq_len].to(cast_dtype)\n", + " return x, attn_mask\n", + "\n", + " def forward_intermediates(\n", + " self,\n", + " text: torch.Tensor,\n", + " indices: Optional[Union[int, List[int]]] = None,\n", + " stop_early: bool = False,\n", + " normalize_intermediates: bool = False,\n", + " intermediates_only: bool = False,\n", + " output_fmt: str = 'NCHW',\n", + " output_extra_tokens: bool = False,\n", + " ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:\n", + " \"\"\" Forward features that returns intermediates.\n", + "\n", + " Args:\n", + " text: Input text ids\n", + " indices: Take last n blocks if int, all if None, select matching indices if sequence\n", + " stop_early: Stop iterating over blocks when last desired intermediate hit\n", + " normalize_intermediates: Apply norm layer to all intermediates\n", + " intermediates_only: Only return intermediate features\n", + " output_fmt: Shape of intermediate feature outputs\n", + " output_extra_tokens: Return both prefix and intermediate tokens\n", + " Returns:\n", + "\n", + " \"\"\"\n", + " assert output_fmt in ('NLC',), 'Output format must be NLC.'\n", + " # forward pass\n", + " x, attn_mask = self._embeds(text)\n", + " x, intermediates = self.transformer.forward_intermediates(\n", + " x,\n", + " attn_mask=attn_mask,\n", + " indices=indices,\n", + " stop_early=stop_early,\n", + " )\n", + "\n", + " # process intermediates\n", + " if normalize_intermediates:\n", + " # apply final norm to all intermediates\n", + " intermediates = [self.ln_final(xi) for xi in intermediates]\n", + "\n", + " output = {}\n", + "\n", + " if self.cls_emb is not None:\n", + " seq_intermediates = [xi[:, :-1] for xi in intermediates] # separate concat'd class token from sequence\n", + " if output_extra_tokens:\n", + " # return suffix class tokens separately\n", + " cls_intermediates = [xi[:, -1:] for xi in intermediates]\n", + " output['text_intermediates_suffix'] = cls_intermediates\n", + " intermediates = seq_intermediates\n", + " output['text_intermediates'] = intermediates\n", + "\n", + " if intermediates_only:\n", + " return output\n", + "\n", + " if self.cls_emb is not None:\n", + " # presence of appended cls embed (CoCa) overrides pool_type, always take last token\n", + " pooled = text_global_pool(x, pool_type='last')\n", + " pooled = self.ln_final(pooled) # final LN applied after pooling in this case\n", + " else:\n", + " x = self.ln_final(x)\n", + " pooled = text_global_pool(x, text, pool_type=self.pool_type)\n", + "\n", + " if self.text_projection is not None:\n", + " if isinstance(self.text_projection, nn.Linear):\n", + " pooled = self.text_projection(pooled)\n", + " else:\n", + " pooled = pooled 
@ self.text_projection\n", + "\n", + " output['text_features'] = pooled\n", + "\n", + " return output\n", + "\n", + " def prune_intermediate_layers(\n", + " self,\n", + " indices: Union[int, List[int]] = 1,\n", + " prune_norm: bool = False,\n", + " prune_head: bool = True,\n", + " ):\n", + " \"\"\" Prune layers not required for specified intermediates.\n", + " \"\"\"\n", + " take_indices = self.transformer.prune_intermediate_layers(indices)\n", + " if prune_norm:\n", + " self.ln_final = nn.Identity()\n", + " if prune_head:\n", + " self.text_projection = None\n", + " return take_indices\n", + "\n", + " def forward(self, text):\n", + " x, attn_mask = self._embeds(text)\n", + "\n", + " x = self.transformer(x, attn_mask=attn_mask)\n", + "\n", + " # x.shape = [batch_size, n_ctx, transformer.width]\n", + " if self.cls_emb is not None:\n", + " # presence of appended cls embed (CoCa) overrides pool_type, always take last token\n", + " pooled = text_global_pool(x, pool_type='last')\n", + " pooled = self.ln_final(pooled) # final LN applied after pooling in this case\n", + " tokens = x[:, :-1]\n", + " else:\n", + " x = self.ln_final(x)\n", + " pooled = text_global_pool(x, text, pool_type=self.pool_type)\n", + " tokens = x\n", + "\n", + " if self.text_projection is not None:\n", + " if isinstance(self.text_projection, nn.Linear):\n", + " pooled = self.text_projection(pooled)\n", + " else:\n", + " pooled = pooled @ self.text_projection\n", + "\n", + " if self.output_tokens:\n", + " return pooled, tokens\n", + "\n", + " return pooled\n", + "\n", + "\n", + "class MultimodalTransformer(Transformer):\n", + " def __init__(\n", + " self,\n", + " width: int,\n", + " layers: int,\n", + " heads: int,\n", + " context_length: int = 77,\n", + " mlp_ratio: float = 4.0,\n", + " ls_init_value: float = None,\n", + " act_layer: Callable = nn.GELU,\n", + " norm_layer: Callable = LayerNorm,\n", + " output_dim: int = 512,\n", + " batch_first: bool = True,\n", + " ):\n", + " super().__init__(\n", + " width=width,\n", + " layers=layers,\n", + " heads=heads,\n", + " mlp_ratio=mlp_ratio,\n", + " ls_init_value=ls_init_value,\n", + " act_layer=act_layer,\n", + " norm_layer=norm_layer,\n", + " batch_first=batch_first,\n", + " )\n", + " self.context_length = context_length\n", + " self.cross_attn = nn.ModuleList([\n", + " ResidualAttentionBlock(\n", + " width,\n", + " heads,\n", + " mlp_ratio,\n", + " ls_init_value=ls_init_value,\n", + " act_layer=act_layer,\n", + " norm_layer=norm_layer,\n", + " is_cross_attention=True,\n", + " batch_first=batch_first,\n", + " )\n", + " for _ in range(layers)\n", + " ])\n", + "\n", + " self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)\n", + "\n", + " self.ln_final = norm_layer(width)\n", + " self.text_projection = nn.Parameter(torch.empty(width, output_dim))\n", + "\n", + " def init_parameters(self):\n", + " proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)\n", + " attn_std = self.transformer.width ** -0.5\n", + " fc_std = (2 * self.transformer.width) ** -0.5\n", + " for block in self.transformer.resblocks:\n", + " nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n", + " nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n", + " nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n", + " nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n", + " for block in self.transformer.cross_attn:\n", + " nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n", + " nn.init.normal_(block.attn.out_proj.weight, 
std=proj_std)\n", + " nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n", + " nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n", + "\n", + " if self.text_projection is not None:\n", + " nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)\n", + "\n", + " def build_attention_mask(self):\n", + " # lazily create causal attention mask, with full attention between the tokens\n", + " # pytorch uses additive attention mask; fill with -inf\n", + " mask = torch.empty(self.context_length, self.context_length)\n", + " mask.fill_(float(\"-inf\"))\n", + " mask.triu_(1) # zero out the lower diagonal\n", + " return mask\n", + "\n", + " def forward_intermediates(\n", + " self,\n", + " x: torch.Tensor,\n", + " attn_mask: Optional[torch.Tensor] = None,\n", + " indices: Optional[Union[int, List[int]]] = None,\n", + " stop_early: bool = False,\n", + " ):\n", + " assert False, \"Not currently implemented for MultimodalTransformer w/ xattn\"\n", + "\n", + " def forward(self, image_embs, text_embs):\n", + " seq_len = text_embs.shape[1]\n", + " if not self.batch_first:\n", + " image_embs = image_embs.permute(1, 0, 2) # NLD -> LND\n", + " text_embs = text_embs.permute(1, 0, 2) # NLD -> LND\n", + "\n", + " for resblock, cross_attn in zip(self.resblocks, self.cross_attn):\n", + " if self.grad_checkpointing and not torch.jit.is_scripting():\n", + " # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372\n", + " text_embs = checkpoint(\n", + " resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len], use_reentrant=False)\n", + " text_embs = checkpoint(\n", + " cross_attn, text_embs, image_embs, image_embs, None, use_reentrant=False)\n", + " else:\n", + " text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])\n", + " text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)\n", + "\n", + " if not self.batch_first:\n", + " text_embs = text_embs.permute(1, 0, 2) # LND -> NLD\n", + "\n", + " out = self.ln_final(text_embs)\n", + " if self.text_projection is not None:\n", + " out = out @ self.text_projection\n", + "\n", + " return out\n", + "\n", + " @torch.jit.ignore\n", + " def set_grad_checkpointing(self, enable=True):\n", + " self.grad_checkpointing = enable\n", + "\n", + "\n", + "\n", + "@dataclass\n", + "class CLIPVisionCfg:\n", + " layers: Union[Tuple[int, int, int, int], int] = 12\n", + " width: int = 768\n", + " head_width: int = 64\n", + " mlp_ratio: float = 4.0\n", + " patch_size: int = 16\n", + " image_size: Union[Tuple[int, int], int] = 224\n", + "\n", + " ls_init_value: Optional[float] = None # layer scale initial value\n", + " patch_dropout: float = 0. 
# what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results\n", + " attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type)\n", + " attn_pooler_queries: int = 256 # n_queries for attentional pooler\n", + " attn_pooler_heads: int = 8 # n heads for attentional_pooling\n", + " no_ln_pre: bool = False # disable pre transformer LayerNorm\n", + " pos_embed_type: str = 'learnable'\n", + " final_ln_after_pool: bool = False # apply final LayerNorm after pooling\n", + " pool_type: str = 'tok'\n", + " output_tokens: bool = False\n", + " act_kwargs: Optional[dict] = None\n", + " norm_kwargs: Optional[dict] = None\n", + "\n", + " timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size\n", + " timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model\n", + " timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')\n", + " timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')\n", + " timm_proj_bias: bool = False # enable bias final projection\n", + " timm_drop: float = 0. # head dropout\n", + " timm_drop_path: Optional[float] = None # backbone stochastic depth\n", + "\n", + "\n", + "@dataclass\n", + "class CLIPTextCfg:\n", + " context_length: int = 77\n", + " vocab_size: int = 49408\n", + " hf_tokenizer_name: Optional[str] = None\n", + " tokenizer_kwargs: Optional[dict] = None\n", + "\n", + " width: int = 512\n", + " heads: int = 8\n", + " layers: int = 12\n", + " mlp_ratio: float = 4.0\n", + " ls_init_value: Optional[float] = None # layer scale initial value\n", + " embed_cls: bool = False\n", + " pad_id: int = 0\n", + " no_causal_mask: bool = False # disable causal masking\n", + " final_ln_after_pool: bool = False # apply final LayerNorm after pooling\n", + " pool_type: str = 'argmax'\n", + " proj_bias: bool = False\n", + " proj_type: str = 'linear' # control final text projection, 'none' forces no projection\n", + " output_tokens: bool = False\n", + " act_kwargs: dict = None\n", + " norm_kwargs: dict = None\n", + "\n", + " # HuggingFace specific text tower config\n", + " hf_model_name: Optional[str] = None\n", + " hf_model_pretrained: bool = True\n", + " hf_proj_type: str = 'mlp'\n", + " hf_pooler_type: str = 'mean_pooler' # attentional pooling for HF models\n", + "\n", + "\n", + "def get_cast_dtype(precision: str):\n", + " cast_dtype = None\n", + " if precision == 'bf16':\n", + " cast_dtype = torch.bfloat16\n", + " elif precision == 'fp16':\n", + " cast_dtype = torch.float16\n", + " return cast_dtype\n", + "\n", + "\n", + "def get_input_dtype(precision: str):\n", + " input_dtype = None\n", + " if precision in ('bf16', 'pure_bf16'):\n", + " input_dtype = torch.bfloat16\n", + " elif precision in ('fp16', 'pure_fp16'):\n", + " input_dtype = torch.float16\n", + " return input_dtype\n", + "\n", + "\n", + "def _build_vision_tower(\n", + " embed_dim: int,\n", + " vision_cfg: CLIPVisionCfg,\n", + " quick_gelu: bool = False,\n", + " cast_dtype: Optional[torch.dtype] = None\n", + "):\n", + " if isinstance(vision_cfg, dict):\n", + " vision_cfg = CLIPVisionCfg(**vision_cfg)\n", + "\n", + " # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more\n", + " # memory efficient in recent PyTorch releases (>= 1.10).\n", + " # NOTE: timm models always use native GELU regardless of quick_gelu 
flag.\n", + " act_layer = QuickGELU if quick_gelu else nn.GELU\n", + "\n", + " if vision_cfg.timm_model_name:\n", + " visual = TimmModel(\n", + " vision_cfg.timm_model_name,\n", + " pretrained=vision_cfg.timm_model_pretrained,\n", + " pool=vision_cfg.timm_pool,\n", + " proj=vision_cfg.timm_proj,\n", + " proj_bias=vision_cfg.timm_proj_bias,\n", + " drop=vision_cfg.timm_drop,\n", + " drop_path=vision_cfg.timm_drop_path,\n", + " patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None,\n", + " embed_dim=embed_dim,\n", + " image_size=vision_cfg.image_size,\n", + " )\n", + " elif isinstance(vision_cfg.layers, (tuple, list)):\n", + " vision_heads = vision_cfg.width * 32 // vision_cfg.head_width\n", + " visual = ModifiedResNet(\n", + " layers=vision_cfg.layers,\n", + " output_dim=embed_dim,\n", + " heads=vision_heads,\n", + " image_size=vision_cfg.image_size,\n", + " width=vision_cfg.width,\n", + " )\n", + " else:\n", + " vision_heads = vision_cfg.width // vision_cfg.head_width\n", + " norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm\n", + " if vision_cfg.norm_kwargs:\n", + " norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs)\n", + " if vision_cfg.act_kwargs is not None:\n", + " act_layer = partial(act_layer, **vision_cfg.act_kwargs)\n", + "\n", + " visual = VisionTransformer(\n", + " image_size=vision_cfg.image_size,\n", + " patch_size=vision_cfg.patch_size,\n", + " width=vision_cfg.width,\n", + " layers=vision_cfg.layers,\n", + " heads=vision_heads,\n", + " mlp_ratio=vision_cfg.mlp_ratio,\n", + " ls_init_value=vision_cfg.ls_init_value,\n", + " patch_dropout=vision_cfg.patch_dropout,\n", + " attentional_pool=vision_cfg.attentional_pool,\n", + " attn_pooler_queries=vision_cfg.attn_pooler_queries,\n", + " attn_pooler_heads=vision_cfg.attn_pooler_heads,\n", + " pos_embed_type=vision_cfg.pos_embed_type,\n", + " no_ln_pre=vision_cfg.no_ln_pre,\n", + " final_ln_after_pool=vision_cfg.final_ln_after_pool,\n", + " pool_type=vision_cfg.pool_type,\n", + " output_tokens=vision_cfg.output_tokens,\n", + " output_dim=embed_dim,\n", + " act_layer=act_layer,\n", + " norm_layer=norm_layer,\n", + " )\n", + "\n", + " return visual\n", + "\n", + "\n", + "def _build_text_tower(\n", + " embed_dim: int,\n", + " text_cfg: CLIPTextCfg,\n", + " quick_gelu: bool = False,\n", + " cast_dtype: Optional[torch.dtype] = None,\n", + "):\n", + " if isinstance(text_cfg, dict):\n", + " text_cfg = CLIPTextCfg(**text_cfg)\n", + "\n", + " if text_cfg.hf_model_name:\n", + " text = HFTextEncoder(\n", + " text_cfg.hf_model_name,\n", + " output_dim=embed_dim,\n", + " proj_type=text_cfg.hf_proj_type,\n", + " pooler_type=text_cfg.hf_pooler_type,\n", + " pretrained=text_cfg.hf_model_pretrained,\n", + " output_tokens=text_cfg.output_tokens,\n", + " )\n", + " else:\n", + " act_layer = QuickGELU if quick_gelu else nn.GELU\n", + " norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm\n", + " if text_cfg.norm_kwargs:\n", + " norm_layer = partial(norm_layer, **text_cfg.norm_kwargs)\n", + " if text_cfg.act_kwargs is not None:\n", + " act_layer = partial(act_layer, **text_cfg.act_kwargs)\n", + "\n", + " text = TextTransformer(\n", + " context_length=text_cfg.context_length,\n", + " vocab_size=text_cfg.vocab_size,\n", + " width=text_cfg.width,\n", + " heads=text_cfg.heads,\n", + " layers=text_cfg.layers,\n", + " mlp_ratio=text_cfg.mlp_ratio,\n", + " ls_init_value=text_cfg.ls_init_value,\n", + " output_dim=embed_dim,\n", + " 
embed_cls=text_cfg.embed_cls,\n", + " no_causal_mask=text_cfg.no_causal_mask,\n", + " pad_id=text_cfg.pad_id,\n", + " pool_type=text_cfg.pool_type,\n", + " proj_type=text_cfg.proj_type,\n", + " proj_bias=text_cfg.proj_bias,\n", + " output_tokens=text_cfg.output_tokens,\n", + " act_layer=act_layer,\n", + " norm_layer=norm_layer,\n", + " )\n", + " return text\n", + "\n", + "\n", + "class CLIP(nn.Module):\n", + " output_dict: torch.jit.Final[bool]\n", + "\n", + " def __init__(\n", + " self,\n", + " embed_dim: int,\n", + " vision_cfg: CLIPVisionCfg,\n", + " text_cfg: CLIPTextCfg,\n", + " quick_gelu: bool = False,\n", + " init_logit_scale: float = np.log(1 / 0.07),\n", + " init_logit_bias: Optional[float] = None,\n", + " nonscalar_logit_scale: bool = False,\n", + " cast_dtype: Optional[torch.dtype] = None,\n", + " output_dict: bool = False,\n", + " ):\n", + " super().__init__()\n", + " self.output_dict = output_dict\n", + "\n", + " self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)\n", + "\n", + " text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)\n", + " self.transformer = text.transformer\n", + " self.context_length = text.context_length\n", + " self.vocab_size = text.vocab_size\n", + " self.token_embedding = text.token_embedding\n", + " self.positional_embedding = text.positional_embedding\n", + " self.ln_final = text.ln_final\n", + " self.text_projection = text.text_projection\n", + " self.text_pool_type = text.pool_type\n", + " self.register_buffer('attn_mask', text.attn_mask, persistent=False)\n", + "\n", + " lshape = [1] if nonscalar_logit_scale else []\n", + " self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)\n", + " if init_logit_bias is not None:\n", + " self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)\n", + " else:\n", + " self.logit_bias = None\n", + "\n", + " def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):\n", + " # lock image tower as per LiT - https://arxiv.org/abs/2111.07991\n", + " self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)\n", + "\n", + " @torch.jit.ignore\n", + " def set_grad_checkpointing(self, enable=True):\n", + " self.visual.set_grad_checkpointing(enable)\n", + " self.transformer.grad_checkpointing = enable\n", + "\n", + " @torch.jit.ignore\n", + " def no_weight_decay(self):\n", + " # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default\n", + " no_wd = {'positional_embedding'}\n", + " if hasattr(self.visual, 'no_weight_decay'):\n", + " for n in self.visual.no_weight_decay():\n", + " no_wd.add('visual.' 
+ n)\n", + " return no_wd\n", + "\n", + " def encode_image(self, image, normalize: bool = False):\n", + " features = self.visual(image)\n", + " return F.normalize(features, dim=-1) if normalize else features\n", + "\n", + " def encode_text(self, text, normalize: bool = False):\n", + " cast_dtype = self.transformer.get_cast_dtype()\n", + "\n", + " x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]\n", + "\n", + " x = x + self.positional_embedding.to(cast_dtype)\n", + " x = self.transformer(x, attn_mask=self.attn_mask)\n", + " x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]\n", + " x = text_global_pool(x, text, self.text_pool_type)\n", + " if self.text_projection is not None:\n", + " if isinstance(self.text_projection, nn.Linear):\n", + " x = self.text_projection(x)\n", + " else:\n", + " x = x @ self.text_projection\n", + "\n", + " return F.normalize(x, dim=-1) if normalize else x\n", + "\n", + " def get_logits(self, image, text):\n", + " image_features = self.encode_image(image, normalize=True)\n", + " text_features = self.encode_text(text, normalize=True)\n", + " image_logits = self.logit_scale.exp() * image_features @ text_features.T\n", + " if self.logit_bias is not None:\n", + " image_logits += self.logit_bias\n", + " text_logits = image_logits.T\n", + " return image_logits, text_logits\n", + "\n", + " def forward_intermediates(\n", + " self,\n", + " image: Optional[torch.Tensor] = None,\n", + " text: Optional[torch.Tensor] = None,\n", + " image_indices: Optional[Union[int, List[int]]] = None,\n", + " text_indices: Optional[Union[int, List[int]]] = None,\n", + " stop_early: bool = False,\n", + " normalize: bool = True,\n", + " normalize_intermediates: bool = False,\n", + " intermediates_only: bool = False,\n", + " image_output_fmt: str = 'NCHW',\n", + " image_output_extra_tokens: bool = False,\n", + " text_output_fmt: str = 'NLC',\n", + " text_output_extra_tokens: bool = False,\n", + " output_logits: bool = False,\n", + " output_logit_scale_bias: bool = False,\n", + " ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:\n", + " \"\"\" Forward features that returns intermediates.\n", + "\n", + " Args:\n", + " image: Input image tensor\n", + " text: Input text tensor\n", + " image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence\n", + " text_indices: Take last n blocks if int, all if None, select matching indices if sequence\n", + " stop_early: Stop iterating over blocks when last desired intermediate hit\n", + " normalize_intermediates: Apply final norm layer to all intermediates\n", + " normalize: L2 Normalize final features\n", + " intermediates_only: Only return intermediate features, do not return final features\n", + " image_output_fmt: Shape of intermediate image feature outputs\n", + " image_output_extra_tokens: Return both prefix and spatial intermediate tokens\n", + " text_output_fmt: Shape of intermediate text feature outputs (ignored for this model)\n", + " text_output_extra_tokens: Return both prefix and spatial intermediate tokens (ignored for this model)\n", + " output_logits: Include logits in output\n", + " output_logit_scale_bias: Include the logit scale bias in the output\n", + " Returns:\n", + "\n", + " \"\"\"\n", + " output = {}\n", + " if intermediates_only:\n", + " # intermediates only disables final feature normalization, and include logits\n", + " normalize = False\n", + " output_logits = False\n", + " if output_logits:\n", + " assert image is not None and text is not 
None, 'Both image and text inputs are required to compute logits'\n", + "\n", + " if image is not None:\n", + " image_output = self.visual.forward_intermediates(\n", + " image,\n", + " indices=image_indices,\n", + " stop_early=stop_early,\n", + " normalize_intermediates=normalize_intermediates,\n", + " intermediates_only=intermediates_only,\n", + " output_fmt=image_output_fmt,\n", + " output_extra_tokens=image_output_extra_tokens,\n", + " )\n", + " if normalize and \"image_features\" in image_output:\n", + " image_output[\"image_features\"] = F.normalize(image_output[\"image_features\"], dim=-1)\n", + " output.update(image_output)\n", + "\n", + " if text is not None:\n", + " cast_dtype = self.transformer.get_cast_dtype()\n", + " x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]\n", + " x = x + self.positional_embedding.to(cast_dtype)\n", + " x, intermediates = self.transformer.forward_intermediates(\n", + " x,\n", + " attn_mask=self.attn_mask,\n", + " indices=text_indices\n", + " )\n", + " if normalize_intermediates:\n", + " intermediates = [self.ln_final(xi) for xi in intermediates]\n", + "\n", + " # NOTE this model doesn't support cls embed in text transformer, no need for extra intermediate tokens\n", + " output[\"text_intermediates\"] = intermediates\n", + "\n", + " if not intermediates_only:\n", + " x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]\n", + " x = text_global_pool(x, text, self.text_pool_type)\n", + " if self.text_projection is not None:\n", + " if isinstance(self.text_projection, nn.Linear):\n", + " x = self.text_projection(x)\n", + " else:\n", + " x = x @ self.text_projection\n", + " if normalize:\n", + " x = F.normalize(x, dim=-1)\n", + " output[\"text_features\"] = x\n", + "\n", + " logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None\n", + "\n", + " if output_logits:\n", + " image_logits = logit_scale_exp * output[\"image_features\"] @ output[\"text_features\"].T\n", + " if self.logit_bias is not None:\n", + " image_logits += self.logit_bias\n", + " text_logits = image_logits.T\n", + " output[\"image_logits\"] = image_logits\n", + " output[\"text_logits\"] = text_logits\n", + "\n", + " if output_logit_scale_bias:\n", + " output[\"logit_scale\"] = logit_scale_exp\n", + " if self.logit_bias is not None:\n", + " output['logit_bias'] = self.logit_bias\n", + "\n", + " return output\n", + "\n", + " def forward(\n", + " self,\n", + " image: Optional[torch.Tensor] = None,\n", + " text: Optional[torch.Tensor] = None,\n", + " ):\n", + " image_features = self.encode_image(image, normalize=True) if image is not None else None\n", + " text_features = self.encode_text(text, normalize=True) if text is not None else None\n", + "\n", + " if self.output_dict:\n", + " out_dict = {\n", + " \"image_features\": image_features,\n", + " \"text_features\": text_features,\n", + " \"logit_scale\": self.logit_scale.exp()\n", + " }\n", + " if self.logit_bias is not None:\n", + " out_dict['logit_bias'] = self.logit_bias\n", + " return out_dict\n", + "\n", + " if self.logit_bias is not None:\n", + " return image_features, text_features, self.logit_scale.exp(), self.logit_bias\n", + " return image_features, text_features, self.logit_scale.exp()\n", + "\n", + "\n", + "class CustomTextCLIP(nn.Module):\n", + " output_dict: torch.jit.Final[bool]\n", + "\n", + " def __init__(\n", + " self,\n", + " embed_dim: int,\n", + " vision_cfg: CLIPVisionCfg,\n", + " text_cfg: CLIPTextCfg,\n", + " quick_gelu: bool = False,\n", + " 
init_logit_scale: float = np.log(1 / 0.07),\n",
+    "            init_logit_bias: Optional[float] = None,\n",
+    "            nonscalar_logit_scale: bool = False,\n",
+    "            cast_dtype: Optional[torch.dtype] = None,\n",
+    "            output_dict: bool = False,\n",
+    "    ):\n",
+    "        super().__init__()\n",
+    "        self.output_dict = output_dict\n",
+    "        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)\n",
+    "        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)\n",
+    "        self.context_length = self.text.context_length\n",
+    "        self.vocab_size = self.text.vocab_size\n",
+    "\n",
+    "        lshape = [1] if nonscalar_logit_scale else []\n",
+    "        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)\n",
+    "        if init_logit_bias is not None:\n",
+    "            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)\n",
+    "        else:\n",
+    "            self.logit_bias = None\n",
+    "\n",
+    "    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):\n",
+    "        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991\n",
+    "        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)\n",
+    "\n",
+    "    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):\n",
+    "        self.text.lock(unlocked_layers, freeze_layer_norm)\n",
+    "\n",
+    "    @torch.jit.ignore\n",
+    "    def set_grad_checkpointing(self, enable=True):\n",
+    "        self.visual.set_grad_checkpointing(enable)\n",
+    "        self.text.set_grad_checkpointing(enable)\n",
+    "\n",
+    "    @torch.jit.ignore\n",
+    "    def no_weight_decay(self):\n",
+    "        # for timm optimizers, 1d params like logit_scale, logit_bias, ln/bn scale, biases are excluded by default\n",
+    "        no_wd = set()\n",
+    "        if hasattr(self.visual, 'no_weight_decay'):\n",
+    "            for n in self.visual.no_weight_decay():\n",
+    "                no_wd.add('visual.' + n)\n",
+    "        if hasattr(self.text, 'no_weight_decay'):\n",
+    "            # collect the text tower's own no-weight-decay params under the 'text.' prefix\n",
+    "            for n in self.text.no_weight_decay():\n",
+    "                no_wd.add('text.' 
+ n)\n", + " return no_wd\n", + "\n", + " def encode_image(self, image, normalize: bool = False):\n", + " features = self.visual(image)\n", + " return F.normalize(features, dim=-1) if normalize else features\n", + "\n", + " def encode_text(self, text, normalize: bool = False):\n", + " features = self.text(text)\n", + " return F.normalize(features, dim=-1) if normalize else features\n", + "\n", + " def get_logits(self, image, text):\n", + " image_features = self.encode_image(image, normalize=True)\n", + " text_features = self.encode_text(text, normalize=True)\n", + " image_logits = self.logit_scale.exp() * image_features @ text_features.T\n", + " if self.logit_bias is not None:\n", + " image_logits += self.logit_bias\n", + " text_logits = image_logits.T\n", + " return image_logits, text_logits\n", + "\n", + " def forward_intermediates(\n", + " self,\n", + " image: Optional[torch.Tensor] = None,\n", + " text: Optional[torch.Tensor] = None,\n", + " image_indices: Optional[Union[int, List[int]]] = None,\n", + " text_indices: Optional[Union[int, List[int]]] = None,\n", + " stop_early: bool = False,\n", + " normalize: bool = True,\n", + " normalize_intermediates: bool = False,\n", + " intermediates_only: bool = False,\n", + " image_output_fmt: str = 'NCHW',\n", + " image_output_extra_tokens: bool = False,\n", + " text_output_fmt: str = 'NLC',\n", + " text_output_extra_tokens: bool = False,\n", + " output_logits: bool = False,\n", + " output_logit_scale_bias: bool = False,\n", + " ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:\n", + " \"\"\" Forward features that returns intermediates.\n", + "\n", + " Args:\n", + " image: Input image tensor\n", + " text: Input text tensor\n", + " image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence\n", + " text_indices: Take last n blocks if int, all if None, select matching indices if sequence\n", + " stop_early: Stop iterating over blocks when last desired intermediate hit\n", + " normalize: L2 Normalize final image and text features (if present)\n", + " normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)\n", + " intermediates_only: Only return intermediate features, do not return final features\n", + " image_output_fmt: Shape of intermediate image feature outputs\n", + " image_output_extra_tokens: Return both prefix and spatial intermediate tokens\n", + " text_output_fmt: Shape of intermediate text feature outputs\n", + " text_output_extra_tokens: Return both prefix and spatial intermediate tokens\n", + " output_logits: Include logits in output\n", + " output_logit_scale_bias: Include the logit scale bias in the output\n", + " Returns:\n", + "\n", + " \"\"\"\n", + " output = {}\n", + " if intermediates_only:\n", + " # intermediates only disables final feature normalization, and include logits\n", + " normalize = False\n", + " output_logits = False\n", + " if output_logits:\n", + " assert image is not None and text is not None, 'Both image and text inputs are required to compute logits'\n", + "\n", + " if image is not None:\n", + " image_output = self.visual.forward_intermediates(\n", + " image,\n", + " indices=image_indices,\n", + " stop_early=stop_early,\n", + " normalize_intermediates=normalize_intermediates,\n", + " intermediates_only=intermediates_only,\n", + " output_fmt=image_output_fmt,\n", + " output_extra_tokens=image_output_extra_tokens,\n", + " )\n", + " if normalize and \"image_features\" in image_output:\n", + " image_output[\"image_features\"] 
= F.normalize(image_output[\"image_features\"], dim=-1)\n", + " output.update(image_output)\n", + "\n", + " if text is not None:\n", + " text_output = self.text.forward_intermediates(\n", + " text,\n", + " indices=text_indices,\n", + " stop_early=stop_early,\n", + " normalize_intermediates=normalize_intermediates,\n", + " intermediates_only=intermediates_only,\n", + " output_fmt=text_output_fmt,\n", + " output_extra_tokens=text_output_extra_tokens,\n", + " )\n", + " if normalize and \"text_features\" in text_output:\n", + " text_output[\"text_features\"] = F.normalize(text_output[\"text_features\"], dim=-1)\n", + " output.update(text_output)\n", + "\n", + " logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None\n", + "\n", + " if output_logits:\n", + " image_logits = logit_scale_exp * output[\"image_features\"] @ output[\"text_features\"].T\n", + " if self.logit_bias is not None:\n", + " image_logits += self.logit_bias\n", + " text_logits = image_logits.T\n", + " output[\"image_logits\"] = image_logits\n", + " output[\"text_logits\"] = text_logits\n", + "\n", + " if output_logit_scale_bias:\n", + " output[\"logit_scale\"] = logit_scale_exp\n", + " if self.logit_bias is not None:\n", + " output['logit_bias'] = self.logit_bias\n", + "\n", + " return output\n", + "\n", + " def forward(\n", + " self,\n", + " image: Optional[torch.Tensor] = None,\n", + " text: Optional[torch.Tensor] = None,\n", + " ):\n", + " image_features = self.encode_image(image, normalize=True) if image is not None else None\n", + " text_features = self.encode_text(text, normalize=True) if text is not None else None\n", + "\n", + " if self.output_dict:\n", + " out_dict = {\n", + " \"image_features\": image_features,\n", + " \"text_features\": text_features,\n", + " \"logit_scale\": self.logit_scale.exp()\n", + " }\n", + " if self.logit_bias is not None:\n", + " out_dict['logit_bias'] = self.logit_bias\n", + " return out_dict\n", + "\n", + " if self.logit_bias is not None:\n", + " return image_features, text_features, self.logit_scale.exp(), self.logit_bias\n", + " return image_features, text_features, self.logit_scale.exp()\n", + "\n", + "\n", + "def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):\n", + " \"\"\"Convert applicable model parameters to low-precision (bf16 or fp16)\"\"\"\n", + "\n", + " def _convert_weights(l):\n", + " if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):\n", + " l.weight.data = l.weight.data.to(dtype)\n", + " if l.bias is not None:\n", + " l.bias.data = l.bias.data.to(dtype)\n", + "\n", + " if isinstance(l, (nn.MultiheadAttention, Attention)):\n", + " for attr in [*[f\"{s}_proj_weight\" for s in [\"in\", \"q\", \"k\", \"v\"]], \"in_proj_bias\", \"bias_k\", \"bias_v\"]:\n", + " tensor = getattr(l, attr)\n", + " if tensor is not None:\n", + " tensor.data = tensor.data.to(dtype)\n", + "\n", + " if isinstance(l, (CLIP, TextTransformer)):\n", + " # convert text nn.Parameter projections\n", + " attr = getattr(l, \"text_projection\", None)\n", + " if attr is not None:\n", + " attr.data = attr.data.to(dtype)\n", + "\n", + " if isinstance(l, VisionTransformer):\n", + " # convert vision nn.Parameter projections\n", + " attr = getattr(l, \"proj\", None)\n", + " if attr is not None:\n", + " attr.data = attr.data.to(dtype)\n", + "\n", + " model.apply(_convert_weights)\n", + "\n", + "\n", + "convert_weights_to_fp16 = convert_weights_to_lp # backwards compat\n", + "\n", + "\n", + "# used to maintain checkpoint compatibility\n", + "def 
convert_to_custom_text_state_dict(state_dict: dict):\n", + " if 'text_projection' in state_dict:\n", + " # old format state_dict, move text tower -> .text\n", + " new_state_dict = {}\n", + " for k, v in state_dict.items():\n", + " if any(k.startswith(p) for p in (\n", + " 'text_projection',\n", + " 'positional_embedding',\n", + " 'token_embedding',\n", + " 'transformer',\n", + " 'ln_final',\n", + " )):\n", + " k = 'text.' + k\n", + " new_state_dict[k] = v\n", + " return new_state_dict\n", + " return state_dict\n", + "\n", + "\n", + "def build_model_from_openai_state_dict(\n", + " state_dict: dict,\n", + " quick_gelu=True,\n", + " cast_dtype=torch.float16,\n", + "):\n", + " vit = \"visual.proj\" in state_dict\n", + "\n", + " if vit:\n", + " vision_width = state_dict[\"visual.conv1.weight\"].shape[0]\n", + " vision_layers = len(\n", + " [k for k in state_dict.keys() if k.startswith(\"visual.\") and k.endswith(\".attn.in_proj_weight\")])\n", + " vision_patch_size = state_dict[\"visual.conv1.weight\"].shape[-1]\n", + " grid_size = round((state_dict[\"visual.positional_embedding\"].shape[0] - 1) ** 0.5)\n", + " image_size = vision_patch_size * grid_size\n", + " else:\n", + " counts: list = [\n", + " len(set(k.split(\".\")[2] for k in state_dict if k.startswith(f\"visual.layer{b}\"))) for b in [1, 2, 3, 4]]\n", + " vision_layers = tuple(counts)\n", + " vision_width = state_dict[\"visual.layer1.0.conv1.weight\"].shape[0]\n", + " output_width = round((state_dict[\"visual.attnpool.positional_embedding\"].shape[0] - 1) ** 0.5)\n", + " vision_patch_size = None\n", + " assert output_width ** 2 + 1 == state_dict[\"visual.attnpool.positional_embedding\"].shape[0]\n", + " image_size = output_width * 32\n", + "\n", + " embed_dim = state_dict[\"text_projection\"].shape[1]\n", + " context_length = state_dict[\"positional_embedding\"].shape[0]\n", + " vocab_size = state_dict[\"token_embedding.weight\"].shape[0]\n", + " transformer_width = state_dict[\"ln_final.weight\"].shape[0]\n", + " transformer_heads = transformer_width // 64\n", + " transformer_layers = len(set(k.split(\".\")[2] for k in state_dict if k.startswith(f\"transformer.resblocks\")))\n", + "\n", + " vision_cfg = CLIPVisionCfg(\n", + " layers=vision_layers,\n", + " width=vision_width,\n", + " patch_size=vision_patch_size,\n", + " image_size=image_size,\n", + " )\n", + " text_cfg = CLIPTextCfg(\n", + " context_length=context_length,\n", + " vocab_size=vocab_size,\n", + " width=transformer_width,\n", + " heads=transformer_heads,\n", + " layers=transformer_layers,\n", + " )\n", + " model = CLIP(\n", + " embed_dim,\n", + " vision_cfg=vision_cfg,\n", + " text_cfg=text_cfg,\n", + " quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU\n", + " cast_dtype=cast_dtype,\n", + " )\n", + "\n", + " for key in [\"input_resolution\", \"context_length\", \"vocab_size\"]:\n", + " state_dict.pop(key, None)\n", + " convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16\n", + " model.load_state_dict(state_dict)\n", + " return model.eval()\n", + "\n", + "\n", + "def trace_model(model, batch_size=256, device=torch.device('cpu')):\n", + " model.eval()\n", + " image_size = model.visual.image_size\n", + " example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)\n", + " example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)\n", + " model = torch.jit.trace_module(\n", + " model,\n", + " inputs=dict(\n", + " forward=(example_images, example_text),\n", + " 
encode_text=(example_text,),\n", + " encode_image=(example_images,)\n", + " ))\n", + " model.visual.image_size = image_size\n", + " return model\n", + "\n", + "\n", + "def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):\n", + " # Rescale the grid of position embeddings when loading from state_dict\n", + " old_pos_embed = state_dict.get('visual.positional_embedding', None)\n", + " if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):\n", + " return\n", + " grid_size = to_2tuple(model.visual.grid_size)\n", + " extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)\n", + " new_seq_len = grid_size[0] * grid_size[1] + extra_tokens\n", + " if new_seq_len == old_pos_embed.shape[0]:\n", + " return\n", + "\n", + " if extra_tokens:\n", + " pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]\n", + " else:\n", + " pos_emb_tok, pos_emb_img = None, old_pos_embed\n", + " old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))\n", + "\n", + " logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)\n", + " pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)\n", + " pos_emb_img = F.interpolate(\n", + " pos_emb_img,\n", + " size=grid_size,\n", + " mode=interpolation,\n", + " antialias=antialias,\n", + " align_corners=False,\n", + " )\n", + " pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]\n", + " if pos_emb_tok is not None:\n", + " new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)\n", + " else:\n", + " new_pos_embed = pos_emb_img\n", + " state_dict['visual.positional_embedding'] = new_pos_embed\n", + "\n", + "\n", + "def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False):\n", + " old_pos_embed = state_dict.get('positional_embedding', None)\n", + " if old_pos_embed is None:\n", + " return\n", + " # FIXME add support for text cls_token\n", + " model_pos_embed = getattr(model, 'positional_embedding', None)\n", + " if model_pos_embed is None:\n", + " model_pos_embed = getattr(model.text, 'positional_embedding', None)\n", + "\n", + " old_num_pos = old_pos_embed.shape[0]\n", + " old_width = old_pos_embed.shape[1]\n", + " num_pos = model_pos_embed.shape[0]\n", + " width = model_pos_embed.shape[1]\n", + " assert old_width == width, 'text pos_embed width changed!'\n", + " if old_num_pos == num_pos:\n", + " return\n", + "\n", + " logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos)\n", + " old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1)\n", + " old_pos_embed = F.interpolate(\n", + " old_pos_embed,\n", + " size=num_pos,\n", + " mode=interpolation,\n", + " antialias=antialias,\n", + " align_corners=False,\n", + " )\n", + " old_pos_embed = old_pos_embed.permute(0, 2, 1)[0]\n", + " new_pos_embed = old_pos_embed\n", + "\n", + " state_dict['positional_embedding'] = new_pos_embed\n", + "\n", + "\n", + "def get_model_preprocess_cfg(model):\n", + " module = getattr(model, 'visual', model)\n", + " preprocess_cfg = getattr(module, 'preprocess_cfg', {})\n", + " if not preprocess_cfg:\n", + " # use separate legacy attributes if preprocess_cfg dict not found\n", + " size = getattr(module, 'image_size')\n", + " if size is not None:\n", + " preprocess_cfg['size'] = size\n", + " mean = getattr(module, 'image_mean', None)\n", + " if mean is not None:\n", + 
" preprocess_cfg['mean'] = mean\n", + " std = getattr(module, 'image_std', None)\n", + " if std is not None:\n", + " preprocess_cfg['std'] = std\n", + " return preprocess_cfg\n", + "\n", + "\n", + "def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):\n", + " module = getattr(model, 'visual', model)\n", + " module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat\n", + " module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat\n", + " module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict\n", + "\n", + "\n", + "def get_model_tokenize_cfg(model):\n", + " module = getattr(model, 'text', model)\n", + " cfg = {}\n", + " context_length = getattr(module, 'context_length', None)\n", + " if context_length is not None:\n", + " cfg['context_length'] = context_length\n", + " vocab_size = getattr(module, 'vocab_size', None)\n", + " if vocab_size is not None:\n", + " cfg['vocab_size'] = vocab_size\n", + " return cfg\n", + "\n", + "\n", + "\n", + "try:\n", + " from huggingface_hub import hf_hub_download\n", + " hf_hub_download = partial(hf_hub_download, library_name=\"open_clip\", library_version=__version__)\n", + " _has_hf_hub = True\n", + "except ImportError:\n", + " hf_hub_download = None\n", + " _has_hf_hub = False\n", + "\n", + "\n", + "def _pcfg(url='', hf_hub='', **kwargs):\n", + " # OpenAI / OpenCLIP defaults\n", + " return {\n", + " 'url': url,\n", + " 'hf_hub': hf_hub,\n", + " 'mean': OPENAI_DATASET_MEAN,\n", + " 'std': OPENAI_DATASET_STD,\n", + " 'interpolation': 'bicubic',\n", + " 'resize_mode': 'shortest',\n", + " **kwargs,\n", + " }\n", + "\n", + "\n", + "def _slpcfg(url='', hf_hub='', **kwargs):\n", + " # SiGLIP defaults\n", + " return {\n", + " 'url': url,\n", + " 'hf_hub': hf_hub,\n", + " 'mean': INCEPTION_MEAN,\n", + " 'std': INCEPTION_STD,\n", + " 'interpolation': 'bicubic',\n", + " 'resize_mode': 'squash',\n", + " **kwargs,\n", + " }\n", + "\n", + "\n", + "def _apcfg(url='', hf_hub='', **kwargs):\n", + " # CLIPA defaults\n", + " return {\n", + " 'url': url,\n", + " 'hf_hub': hf_hub,\n", + " 'mean': IMAGENET_MEAN,\n", + " 'std': IMAGENET_STD,\n", + " 'interpolation': 'bilinear',\n", + " 'resize_mode': 'squash',\n", + " **kwargs,\n", + " }\n", + "\n", + "\n", + "def _mccfg(url='', hf_hub='', **kwargs):\n", + " # MobileCLIP\n", + " return {\n", + " 'url': url,\n", + " 'hf_hub': hf_hub,\n", + " 'mean': (0., 0., 0.),\n", + " 'std': (1., 1., 1.),\n", + " 'interpolation': 'bilinear',\n", + " 'resize_mode': 'shortest',\n", + " **kwargs,\n", + " }\n", + "\n", + "\n", + "\n", + "_RN50 = dict(\n", + " openai=_pcfg(\n", + " url=\"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt\",\n", + " hf_hub=\"timm/resnet50_clip.openai/\",\n", + " quick_gelu=True,\n", + " ),\n", + " yfcc15m=_pcfg(\n", + " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt\",\n", + " hf_hub=\"timm/resnet50_clip.yfcc15m/\",\n", + " quick_gelu=True,\n", + " ),\n", + " cc12m=_pcfg(\n", + " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt\",\n", + " hf_hub=\"timm/resnet50_clip.cc12m/\",\n", + " quick_gelu=True,\n", + " ),\n", + ")\n", + "\n", + "_RN101 = dict(\n", + " openai=_pcfg(\n", + " url=\"https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt\",\n", + " 
hf_hub=\"timm/resnet101_clip.openai/\",\n", + " quick_gelu=True,\n", + " ),\n", + " yfcc15m=_pcfg(\n", + " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt\",\n", + " hf_hub=\"timm/resnet101_clip.yfcc15m/\",\n", + " quick_gelu=True,\n", + " ),\n", + ")\n", + "\n", + "_RN50x4 = dict(\n", + " openai=_pcfg(\n", + " url=\"https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt\",\n", + " hf_hub=\"timm/resnet50x4_clip.openai/\",\n", + " quick_gelu=True,\n", + " ),\n", + ")\n", + "\n", + "_RN50x16 = dict(\n", + " openai=_pcfg(\n", + " url=\"https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt\",\n", + " hf_hub=\"timm/resnet50x16_clip.openai/\",\n", + " quick_gelu=True,\n", + " ),\n", + ")\n", + "\n", + "_RN50x64 = dict(\n", + " openai=_pcfg(\n", + " url=\"https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt\",\n", + " hf_hub=\"timm/resnet50x64_clip.openai/\",\n", + " quick_gelu=True,\n", + " ),\n", + ")\n", + "\n", + "_VITB32 = dict(\n", + " openai=_pcfg(\n", + " url=\"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt\",\n", + " hf_hub=\"timm/vit_base_patch32_clip_224.openai/\",\n", + " quick_gelu=True,\n", + " ),\n", + " # LAION 400M (quick gelu)\n", + " laion400m_e31=_pcfg(\n", + " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt\",\n", + " hf_hub=\"timm/vit_base_patch32_clip_224.laion400m_e31/\",\n", + " quick_gelu=True,\n", + " ),\n", + " laion400m_e32=_pcfg(\n", + " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt\",\n", + " hf_hub=\"timm/vit_base_patch32_clip_224.laion400m_e32/\",\n", + " quick_gelu=True,\n", + " ),\n", + " # LAION 2B-en\n", + " laion2b_e16=_pcfg(\n", + " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth\",\n", + " hf_hub=\"timm/vit_base_patch32_clip_224.laion2b_e16/\",\n", + " ),\n", + " laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'),\n", + " # DataComp-XL models\n", + " datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/'),\n", + " # DataComp-M models\n", + " datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'),\n", + " commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'),\n", + " commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'),\n", + " commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'),\n", + " commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'),\n", + " commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'),\n", + " commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'),\n", + " # DataComp-S models\n", + " datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'),\n", + " commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'),\n", + " commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'),\n", + " 
commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'),\n", + " commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'),\n", + " commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'),\n", + " commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'),\n", + " # MetaClip models (NOTE quick-gelu activation used)\n", + " metaclip_400m=_pcfg(\n", + " url=\"https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt\",\n", + " hf_hub=\"timm/vit_base_patch32_clip_224.metaclip_400m/\",\n", + " quick_gelu=True,\n", + " ),\n", + " metaclip_fullcc=_pcfg(\n", + " url=\"https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt\",\n", + " hf_hub=\"timm/vit_base_patch32_clip_224.metaclip_2pt5b/\",\n", + " quick_gelu=True,\n", + " ),\n", + ")\n", + "\n", + "_VITB32_256 = dict(\n", + " datacomp_s34b_b86k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K/'),\n", + ")\n", + "\n", + "_VITB16 = dict(\n", + " openai=_pcfg(\n", + " url=\"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt\",\n", + " hf_hub=\"timm/vit_base_patch16_clip_224.openai/\",\n", + " quick_gelu=True,\n", + " ),\n", + " # LAION-400M\n", + " laion400m_e31=_pcfg(\n", + " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt\",\n", + " hf_hub=\"timm/vit_base_patch16_clip_224.laion400m_e31/\",\n", + " ),\n", + " laion400m_e32=_pcfg(\n", + " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt\",\n", + " hf_hub=\"timm/vit_base_patch16_clip_224.laion400m_e32/\",\n", + " ),\n", + " # LAION-2B\n", + " laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),\n", + " # DataComp-XL models\n", + " datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K/'),\n", + " # DataComp-L models\n", + " datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'),\n", + " commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'),\n", + " commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'),\n", + " commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'),\n", + " commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'),\n", + " commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'),\n", + " commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'),\n", + " # DFN\n", + " dfn2b=_pcfg(\n", + " hf_hub='apple/DFN2B-CLIP-ViT-B-16/',\n", + " quick_gelu=True,\n", + " ),\n", + " # MetaCLIP (these are quick-gelu)\n", + " metaclip_400m=_pcfg(\n", + " url=\"https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt\",\n", + " hf_hub=\"timm/vit_base_patch16_clip_224.metaclip_400m/\",\n", + " quick_gelu=True,\n", + " ),\n", + " metaclip_fullcc=_pcfg(\n", + " url=\"https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt\",\n", + " hf_hub=\"timm/vit_base_patch16_clip_224.metaclip_2pt5b/\",\n", + " quick_gelu=True,\n", + " ),\n", + ")\n", + "\n", + "_VITB16_PLUS_240 = dict(\n", + " laion400m_e31=_pcfg(\n", + " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt\",\n", + " 
hf_hub=\"timm/vit_base_patch16_plus_clip_240.laion400m_e31/\",\n", + " ),\n", + " laion400m_e32=_pcfg(\n", + " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt\",\n", + " hf_hub=\"timm/vit_base_patch16_plus_clip_240.laion400m_e31/\",\n", + " ),\n", + ")\n", + "\n", + "_VITL14 = dict(\n", + " openai=_pcfg(\n", + " url=\"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt\",\n", + " hf_hub=\"timm/vit_large_patch14_clip_224.openai/\",\n", + " quick_gelu=True,\n", + " ),\n", + " # LAION-400M\n", + " laion400m_e31=_pcfg(\n", + " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt\",\n", + " hf_hub=\"timm/vit_large_patch14_clip_224.laion400m_e31/\",\n", + " ),\n", + " laion400m_e32=_pcfg(\n", + " url=\"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt\",\n", + " hf_hub=\"timm/vit_large_patch14_clip_224.laion400m_e32/\",\n", + " ),\n", + " # LAION-2B-en\n", + " laion2b_s32b_b82k=_pcfg(\n", + " hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',\n", + " mean=INCEPTION_MEAN, std=INCEPTION_STD),\n", + " # DataComp-XL models\n", + " datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'),\n", + " commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'),\n", + " commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'),\n", + " commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'),\n", + " # MetaCLIP\n", + " metaclip_400m=_pcfg(\n", + " url=\"https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt\",\n", + " hf_hub=\"timm/vit_large_patch14_clip_224.metaclip_400m/\",\n", + " quick_gelu=True,\n", + " ),\n", + " metaclip_fullcc=_pcfg(\n", + " url=\"https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt\",\n", + " hf_hub=\"timm/vit_large_patch14_clip_224.metaclip_2pt5b/\",\n", + " quick_gelu=True,\n", + " ),\n", + " # DFN-2B (quick-gelu)\n", + " dfn2b=_pcfg(\n", + " hf_hub='apple/DFN2B-CLIP-ViT-L-14/',\n", + " quick_gelu=True,\n", + " ),\n", + " # DFN-2B 39B SS\n", + " dfn2b_s39b=_pcfg(\n", + " hf_hub='apple/DFN2B-CLIP-ViT-L-14-39B/',\n", + " ),\n", + ")\n", + "\n", + "_VITL14_336 = dict(\n", + " openai=_pcfg(\n", + " url=\"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt\",\n", + " hf_hub=\"timm/vit_large_patch14_clip_336.openai/\",\n", + " quick_gelu=True,\n", + " ),\n", + ")\n", + "\n", + "_VITH14 = dict(\n", + " # LAION-2B-en\n", + " laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),\n", + " # MetaCLIP (quick-gelu)\n", + " metaclip_fullcc=_pcfg(\n", + " url=\"https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt\",\n", + " hf_hub=\"timm/vit_huge_patch14_clip_224.metaclip_2pt5b/\",\n", + " quick_gelu=True,\n", + " ),\n", + " metaclip_altogether=_pcfg(\n", + " url=\"https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_v1.2_altogether.pt\",\n", + " hf_hub=\"timm/vit_huge_patch14_clip_224.metaclip_altogether/\",\n", + " # NOTE unlike other MetaCLIP models, this is not using QuickGELU, yay!\n", + " ),\n", + " # DFN-5B (quick-gelu)\n", + " dfn5b=_pcfg(\n", + " hf_hub='apple/DFN5B-CLIP-ViT-H-14/',\n", + " quick_gelu=True,\n", + " interpolation=\"bicubic\",\n", + " resize_mode=\"squash\"\n", + " ),\n", 
+ ")\n", + "\n", + "_VITH14_378 = dict(\n", + " # DFN-5B (quick-gelu)\n", + " dfn5b=_pcfg(\n", + " hf_hub='apple/DFN5B-CLIP-ViT-H-14-378/',\n", + " quick_gelu=True,\n", + " interpolation=\"bicubic\",\n", + " resize_mode=\"squash\"\n", + " ),\n", + ")\n", + "\n", + "_VITg14 = dict(\n", + " laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),\n", + " laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),\n", + ")\n", + "\n", + "_VITbigG14 = dict(\n", + " # LAION-2B-en\n", + " laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),\n", + " # MetaCLIP (quick-gelu)\n", + " metaclip_fullcc=_pcfg(\n", + " url='https://dl.fbaipublicfiles.com/MMPT/metaclip/G14_fullcc2.5b.pt',\n", + " hf_hub=\"timm/vit_gigantic_patch14_clip_224.metaclip_2pt5b/\",\n", + " quick_gelu=True,\n", + " ),\n", + ")\n", + "\n", + "_robertaViTB32 = dict(\n", + " laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),\n", + ")\n", + "\n", + "_xlmRobertaBaseViTB32 = dict(\n", + " laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),\n", + ")\n", + "\n", + "_xlmRobertaLargeFrozenViTH14 = dict(\n", + " frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),\n", + ")\n", + "\n", + "_convnext_base = dict(\n", + " laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'),\n", + ")\n", + "\n", + "_convnext_base_w = dict(\n", + " laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),\n", + " laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),\n", + " laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),\n", + ")\n", + "\n", + "_convnext_base_w_320 = dict(\n", + " laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),\n", + " laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),\n", + ")\n", + "\n", + "_convnext_large_d = dict(\n", + " laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'),\n", + ")\n", + "\n", + "_convnext_large_d_320 = dict(\n", + " laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'),\n", + " laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),\n", + ")\n", + "\n", + "_convnext_xxlarge = dict(\n", + " laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'),\n", + " laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'),\n", + " laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'),\n", + ")\n", + "\n", + "_coca_VITB32 = dict(\n", + " laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'),\n", + " mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/')\n", + ")\n", + "\n", + "_coca_VITL14 = dict(\n", + " laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'),\n", + " mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/')\n", + ")\n", + "\n", + "\n", + "_PRETRAINED = {\n", + " \"RN50\": _RN50,\n", + " \"RN101\": _RN101,\n", + " \"RN50x4\": _RN50x4,\n", + " \"RN50x16\": _RN50x16,\n", + " 
\"RN50x64\": _RN50x64,\n", + "\n", + " \"ViT-B-32\": _VITB32,\n", + " \"ViT-B-32-256\": _VITB32_256,\n", + " \"ViT-B-16\": _VITB16,\n", + " \"ViT-B-16-plus-240\": _VITB16_PLUS_240,\n", + " \"ViT-L-14\": _VITL14,\n", + " \"ViT-L-14-336\": _VITL14_336,\n", + " \"ViT-H-14\": _VITH14,\n", + " \"ViT-H-14-378\": _VITH14_378,\n", + " \"ViT-g-14\": _VITg14,\n", + " \"ViT-bigG-14\": _VITbigG14,\n", + "\n", + " \"roberta-ViT-B-32\": _robertaViTB32,\n", + " \"xlm-roberta-base-ViT-B-32\": _xlmRobertaBaseViTB32,\n", + " \"xlm-roberta-large-ViT-H-14\": _xlmRobertaLargeFrozenViTH14,\n", + "\n", + " \"convnext_base\": _convnext_base,\n", + " \"convnext_base_w\": _convnext_base_w,\n", + " \"convnext_base_w_320\": _convnext_base_w_320,\n", + " \"convnext_large_d\": _convnext_large_d,\n", + " \"convnext_large_d_320\": _convnext_large_d_320,\n", + " \"convnext_xxlarge\": _convnext_xxlarge,\n", + "\n", + " \"coca_ViT-B-32\": _coca_VITB32,\n", + " \"coca_ViT-L-14\": _coca_VITL14,\n", + "\n", + " \"EVA01-g-14\": dict(\n", + " # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt\n", + " laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'),\n", + " ),\n", + " \"EVA01-g-14-plus\": dict(\n", + " # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt\n", + " merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'),\n", + " ),\n", + " \"EVA02-B-16\": dict(\n", + " # from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt\n", + " merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'),\n", + " ),\n", + " \"EVA02-L-14\": dict(\n", + " # from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt\n", + " merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'),\n", + " ),\n", + " \"EVA02-L-14-336\": dict(\n", + " # from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt\n", + " merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'),\n", + " ),\n", + " \"EVA02-E-14\": dict(\n", + " # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt\n", + " laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'),\n", + " ),\n", + " \"EVA02-E-14-plus\": dict(\n", + " # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt\n", + " laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'),\n", + " ),\n", + "\n", + " \"ViT-B-16-SigLIP\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP/'),\n", + " ),\n", + " \"ViT-B-16-SigLIP-256\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-256/'),\n", + " ),\n", + " \"ViT-B-16-SigLIP-i18n-256\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-i18n-256/'),\n", + " ),\n", + " \"ViT-B-16-SigLIP-384\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-384/'),\n", + " ),\n", + " \"ViT-B-16-SigLIP-512\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-512/'),\n", + " ),\n", + " \"ViT-L-16-SigLIP-256\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-256/'),\n", + " ),\n", + " \"ViT-L-16-SigLIP-384\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-384/'),\n", + " ),\n", + " \"ViT-SO400M-14-SigLIP\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'),\n", + " ),\n", + " \"ViT-SO400M-16-SigLIP-i18n-256\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP-i18n-256/'),\n", + " ),\n", + " \"ViT-SO400M-14-SigLIP-378\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'), # NOTE using 384 weights, 
but diff img_size used\n", + " ),\n", + " \"ViT-SO400M-14-SigLIP-384\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'),\n", + " ),\n", + "\n", + " \"ViT-B-32-SigLIP2-256\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-B-32-SigLIP2-256/'),\n", + " ),\n", + " \"ViT-B-16-SigLIP2\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2/'),\n", + " ),\n", + " \"ViT-B-16-SigLIP2-256\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-256/'),\n", + " ),\n", + " \"ViT-B-16-SigLIP2-384\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-384/'),\n", + " ),\n", + " \"ViT-B-16-SigLIP2-512\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP2-512/'),\n", + " ),\n", + " \"ViT-L-16-SigLIP2-256\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-256/'),\n", + " ),\n", + " \"ViT-L-16-SigLIP2-384\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-384/'),\n", + " ),\n", + " \"ViT-L-16-SigLIP2-512\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP2-512/'),\n", + " ),\n", + " \"ViT-SO400M-14-SigLIP2\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP2/'),\n", + " ),\n", + " \"ViT-SO400M-14-SigLIP2-378\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP2-378/'),\n", + " ),\n", + " \"ViT-SO400M-16-SigLIP2-256\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-256/'),\n", + " ),\n", + " \"ViT-SO400M-16-SigLIP2-384\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-384/'),\n", + " ),\n", + " \"ViT-SO400M-16-SigLIP2-512\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-SO400M-16-SigLIP2-512/'),\n", + " ),\n", + " \"ViT-gopt-16-SigLIP2-256\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-gopt-16-SigLIP2-256/'),\n", + " ),\n", + " \"ViT-gopt-16-SigLIP2-384\": dict(\n", + " webli=_slpcfg(hf_hub='timm/ViT-gopt-16-SigLIP2-384/'),\n", + " ),\n", + "\n", + " \"ViT-L-14-CLIPA\": dict(\n", + " datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-datacomp1B/'),\n", + " ),\n", + " \"ViT-L-14-CLIPA-336\": dict(\n", + " datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-336-datacomp1B/'),\n", + " ),\n", + " \"ViT-H-14-CLIPA\": dict(\n", + " datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-datacomp1B/'),\n", + " ),\n", + " \"ViT-H-14-CLIPA-336\": dict(\n", + " laion2b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-laion2B/'),\n", + " datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-datacomp1B/'),\n", + " ),\n", + " \"ViT-bigG-14-CLIPA\": dict(\n", + " datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-datacomp1B/'),\n", + " ),\n", + " \"ViT-bigG-14-CLIPA-336\": dict(\n", + " datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-336-datacomp1B/'),\n", + " ),\n", + "\n", + " \"nllb-clip-base\": dict(\n", + " v1=_pcfg(hf_hub='visheratin/nllb-clip-base-oc/'),\n", + " ),\n", + " \"nllb-clip-large\": dict(\n", + " v1=_pcfg(hf_hub='visheratin/nllb-clip-large-oc/'),\n", + " ),\n", + "\n", + " \"nllb-clip-base-siglip\": dict(\n", + " v1=_slpcfg(hf_hub='visheratin/nllb-clip-base-siglip/'),\n", + " mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-base/'),\n", + " ),\n", + " \"nllb-clip-large-siglip\": dict(\n", + " v1=_slpcfg(hf_hub='visheratin/nllb-clip-large-siglip/'),\n", + " mrl=_slpcfg(hf_hub='visheratin/nllb-siglip-mrl-large/'),\n", + " ),\n", + "\n", + " \"MobileCLIP-S1\": dict(\n", + " datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S1-OpenCLIP/')),\n", + " \"MobileCLIP-S2\": dict(\n", + " datacompdr=_mccfg(hf_hub='apple/MobileCLIP-S2-OpenCLIP/')),\n", + " \"MobileCLIP-B\": dict(\n", + " 
datacompdr=_mccfg(hf_hub='apple/MobileCLIP-B-OpenCLIP/'),\n", + " datacompdr_lt=_mccfg(hf_hub='apple/MobileCLIP-B-LT-OpenCLIP/'),\n", + " ),\n", + "\n", + " \"ViTamin-S\": dict(\n", + " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S/pytorch_model.bin'),\n", + " ),\n", + " \"ViTamin-S-LTT\": dict(\n", + " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-S-LTT/pytorch_model.bin'),\n", + " ),\n", + " \"ViTamin-B\": dict(\n", + " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B/pytorch_model.bin'),\n", + " ),\n", + " \"ViTamin-B-LTT\": dict(\n", + " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-B-LTT/pytorch_model.bin'),\n", + " ),\n", + " \"ViTamin-L\": dict(\n", + " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-224px/pytorch_model.bin'),\n", + " ),\n", + " \"ViTamin-L-256\": dict(\n", + " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-256px/pytorch_model.bin'),\n", + " ),\n", + " \"ViTamin-L-336\": dict(\n", + " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-336px/pytorch_model.bin'),\n", + " ),\n", + " \"ViTamin-L-384\": dict(\n", + " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L-384px/pytorch_model.bin'),\n", + " ),\n", + " \"ViTamin-L2\": dict(\n", + " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-224px/pytorch_model.bin'),\n", + " ),\n", + " \"ViTamin-L2-256\": dict(\n", + " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-256px/pytorch_model.bin'),\n", + " ),\n", + " \"ViTamin-L2-336\": dict(\n", + " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-336px/pytorch_model.bin'),\n", + " ),\n", + " \"ViTamin-L2-384\": dict(\n", + " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-L2-384px/pytorch_model.bin'),\n", + " ),\n", + " \"ViTamin-XL-256\": dict(\n", + " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-256px/pytorch_model.bin'),\n", + " ),\n", + " \"ViTamin-XL-336\": dict(\n", + " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-336px/pytorch_model.bin'),\n", + " ),\n", + " \"ViTamin-XL-384\": dict(\n", + " datacomp1b=_pcfg(hf_hub='jienengchen/ViTamin-XL-384px/pytorch_model.bin'),\n", + " ),\n", + "}\n", + "\n", + "_PRETRAINED_quickgelu = {}\n", + "for k, v in _PRETRAINED.items():\n", + " quick_gelu_tags = {}\n", + " for tk, tv in v.items():\n", + " if tv.get('quick_gelu', False):\n", + " quick_gelu_tags[tk] = copy.deepcopy(tv)\n", + " if quick_gelu_tags:\n", + " _PRETRAINED_quickgelu[k + '-quickgelu'] = quick_gelu_tags\n", + "_PRETRAINED.update(_PRETRAINED_quickgelu)\n", + "\n", + "def _clean_tag(tag: str):\n", + " # normalize pretrained tags\n", + " return tag.lower().replace('-', '_')\n", + "\n", + "\n", + "def list_pretrained(as_str: bool = False):\n", + " \"\"\" returns list of pretrained models\n", + " Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True\n", + " \"\"\"\n", + " return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]\n", + "\n", + "\n", + "def list_pretrained_models_by_tag(tag: str):\n", + " \"\"\" return all models having the specified pretrain tag \"\"\"\n", + " models = []\n", + " tag = _clean_tag(tag)\n", + " for k in _PRETRAINED.keys():\n", + " if tag in _PRETRAINED[k]:\n", + " models.append(k)\n", + " return models\n", + "\n", + "\n", + "def list_pretrained_tags_by_model(model: str):\n", + " \"\"\" return all pretrain tags for the specified model architecture \"\"\"\n", + " tags = []\n", + " if model in _PRETRAINED:\n", + " tags.extend(_PRETRAINED[model].keys())\n", + " return tags\n", + "\n", + "\n", + "def is_pretrained_cfg(model: str, tag: str):\n", + " if model not in 
_PRETRAINED:\n",
+    "        return False\n",
+    "    return _clean_tag(tag) in _PRETRAINED[model]\n",
+    "\n",
+    "\n",
+    "def get_pretrained_cfg(model: str, tag: str):\n",
+    "    if model not in _PRETRAINED:\n",
+    "        return {}\n",
+    "    model_pretrained = _PRETRAINED[model]\n",
+    "    return model_pretrained.get(_clean_tag(tag), {})\n",
+    "\n",
+    "\n",
+    "def get_pretrained_url(model: str, tag: str):\n",
+    "    cfg = get_pretrained_cfg(model, _clean_tag(tag))\n",
+    "    return cfg.get('url', '')\n",
+    "\n",
+    "\n",
+    "# Download a checkpoint from a direct URL, verifying the SHA256 fragment embedded in\n",
+    "# OpenAI / OpenCLIP release filenames when present.\n",
+    "def download_pretrained_from_url(\n",
+    "        url: str,\n",
+    "        cache_dir: Optional[str] = None,\n",
+    "):\n",
+    "    if not cache_dir:\n",
+    "        cache_dir = os.path.expanduser(\"~/.cache/clip\")\n",
+    "    os.makedirs(cache_dir, exist_ok=True)\n",
+    "    filename = os.path.basename(url)\n",
+    "\n",
+    "    if 'openaipublic' in url:\n",
+    "        expected_sha256 = url.split(\"/\")[-2]\n",
+    "    elif 'mlfoundations' in url:\n",
+    "        expected_sha256 = os.path.splitext(filename)[0].split(\"-\")[-1]\n",
+    "    else:\n",
+    "        expected_sha256 = ''\n",
+    "\n",
+    "    download_target = os.path.join(cache_dir, filename)\n",
+    "\n",
+    "    if os.path.exists(download_target) and not os.path.isfile(download_target):\n",
+    "        raise RuntimeError(f\"{download_target} exists and is not a regular file\")\n",
+    "\n",
+    "    if os.path.isfile(download_target):\n",
+    "        if expected_sha256:\n",
+    "            if hashlib.sha256(open(download_target, \"rb\").read()).hexdigest().startswith(expected_sha256):\n",
+    "                return download_target\n",
+    "            else:\n",
+    "                warnings.warn(f\"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file\")\n",
+    "        else:\n",
+    "            return download_target\n",
+    "\n",
+    "    with urllib.request.urlopen(url) as source, open(download_target, \"wb\") as output:\n",
+    "        with tqdm(total=int(source.headers.get(\"Content-Length\")), ncols=80, unit='iB', unit_scale=True) as loop:\n",
+    "            while True:\n",
+    "                buffer = source.read(8192)\n",
+    "                if not buffer:\n",
+    "                    break\n",
+    "\n",
+    "                output.write(buffer)\n",
+    "                loop.update(len(buffer))\n",
+    "\n",
+    "    if expected_sha256 and not hashlib.sha256(open(download_target, \"rb\").read()).hexdigest().startswith(expected_sha256):\n",
+    "        raise RuntimeError(f\"Model has been downloaded but the SHA256 checksum does not match\")\n",
+    "\n",
+    "    return download_target\n",
+    "\n",
+    "\n",
+    "def has_hf_hub(necessary=False):\n",
+    "    if not _has_hf_hub and necessary:\n",
+    "        # if no HF Hub module installed, and it is necessary to continue, raise error\n",
+    "        raise RuntimeError(\n",
+    "            'Hugging Face hub model specified but package not installed. 
Run `pip install huggingface_hub`.')\n", + " return _has_hf_hub\n", + "\n", + "\n", + "def _get_safe_alternatives(filename: str) -> Iterable[str]:\n", + " \"\"\"Returns potential safetensors alternatives for a given filename.\n", + "\n", + " Use case:\n", + " When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.\n", + " \"\"\"\n", + " if filename == HF_WEIGHTS_NAME:\n", + " yield HF_SAFE_WEIGHTS_NAME\n", + "\n", + " if filename not in (HF_WEIGHTS_NAME,) and (filename.endswith(\".bin\") or filename.endswith(\".pth\")):\n", + " yield filename[:-4] + \".safetensors\"\n", + "\n", + "\n", + "def download_pretrained_from_hf(\n", + " model_id: str,\n", + " filename: Optional[str] = None,\n", + " revision: Optional[str] = None,\n", + " cache_dir: Optional[str] = None,\n", + "):\n", + " has_hf_hub(True)\n", + "\n", + " filename = filename or HF_WEIGHTS_NAME\n", + "\n", + " # Look for .safetensors alternatives and load from it if it exists\n", + " if _has_safetensors:\n", + " for safe_filename in _get_safe_alternatives(filename):\n", + " try:\n", + " cached_file = hf_hub_download(\n", + " repo_id=model_id,\n", + " filename=safe_filename,\n", + " revision=revision,\n", + " cache_dir=cache_dir,\n", + " )\n", + " return cached_file\n", + " except Exception:\n", + " pass\n", + "\n", + " try:\n", + " # Attempt to download the file\n", + " cached_file = hf_hub_download(\n", + " repo_id=model_id,\n", + " filename=filename,\n", + " revision=revision,\n", + " cache_dir=cache_dir,\n", + " )\n", + " return cached_file # Return the path to the downloaded file if successful\n", + " except Exception as e:\n", + " raise FileNotFoundError(f\"Failed to download file ({filename}) for {model_id}. Last error: {e}\")\n", + "\n", + "\n", + "def download_pretrained(\n", + " cfg: Dict,\n", + " prefer_hf_hub: bool = True,\n", + " cache_dir: Optional[str] = None,\n", + "):\n", + " target = ''\n", + " if not cfg:\n", + " return target\n", + "\n", + " if 'file' in cfg:\n", + " return cfg['file']\n", + "\n", + " has_hub = has_hf_hub()\n", + " download_url = cfg.get('url', '')\n", + " download_hf_hub = cfg.get('hf_hub', '')\n", + " if has_hub and prefer_hf_hub and download_hf_hub:\n", + " # prefer to use HF hub, remove url info\n", + " download_url = ''\n", + "\n", + " if download_url:\n", + " target = download_pretrained_from_url(download_url, cache_dir=cache_dir)\n", + " elif download_hf_hub:\n", + " has_hf_hub(True)\n", + " # we assume the hf_hub entries in pretrained config combine model_id + filename in\n", + " # 'org/model_name/filename.pt' form. 
To specify just the model id w/o filename and\n",
+    "        # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.\n",
+    "        model_id, filename = os.path.split(download_hf_hub)\n",
+    "        if filename:\n",
+    "            target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)\n",
+    "        else:\n",
+    "            target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)\n",
+    "\n",
+    "    return target\n",
+    "\n",
+    "# ==================================================================\n",
+    "@dataclass\n",
+    "class PreprocessCfg:\n",
+    "    size: Union[int, Tuple[int, int]] = 224\n",
+    "    mode: str = 'RGB'\n",
+    "    mean: Tuple[float, ...] = OPENAI_DATASET_MEAN\n",
+    "    std: Tuple[float, ...] 
= OPENAI_DATASET_STD\n", + " interpolation: str = 'bicubic'\n", + " resize_mode: str = 'shortest'\n", + " fill_color: int = 0\n", + "\n", + " def __post_init__(self):\n", + " assert self.mode in ('RGB',)\n", + "\n", + " @property\n", + " def num_channels(self):\n", + " return 3\n", + "\n", + " @property\n", + " def input_size(self):\n", + " return (self.num_channels,) + to_2tuple(self.size)\n", + "\n", + "_PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys())\n", + "\n", + "\n", + "def merge_preprocess_dict(\n", + " base: Union[PreprocessCfg, Dict],\n", + " overlay: Dict,\n", + "):\n", + " \"\"\" Merge overlay key-value pairs on top of base preprocess cfg or dict.\n", + " Input dicts are filtered based on PreprocessCfg fields.\n", + " \"\"\"\n", + " if isinstance(base, PreprocessCfg):\n", + " base_clean = asdict(base)\n", + " else:\n", + " base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS}\n", + " if overlay:\n", + " overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None}\n", + " base_clean.update(overlay_clean)\n", + " return base_clean\n", + "\n", + "\n", + "def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs):\n", + " return merge_preprocess_dict(base, kwargs)\n", + "\n", + "\n", + "@dataclass\n", + "class AugmentationCfg:\n", + " scale: Tuple[float, float] = (0.9, 1.0)\n", + " ratio: Optional[Tuple[float, float]] = None\n", + " color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None\n", + " re_prob: Optional[float] = None\n", + " re_count: Optional[int] = None\n", + " use_timm: bool = False\n", + "\n", + " # params for simclr_jitter_gray\n", + " color_jitter_prob: float = None\n", + " gray_scale_prob: float = None\n", + "\n", + "\n", + "def _setup_size(size, error_msg):\n", + " if isinstance(size, numbers.Number):\n", + " return int(size), int(size)\n", + "\n", + " if isinstance(size, Sequence) and len(size) == 1:\n", + " return size[0], size[0]\n", + "\n", + " if len(size) != 2:\n", + " raise ValueError(error_msg)\n", + "\n", + " return size\n", + "\n", + "\n", + "class ResizeKeepRatio:\n", + " \"\"\" Resize and Keep Ratio\n", + "\n", + " Copy & paste from `timm`\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " size,\n", + " longest=0.,\n", + " interpolation=InterpolationMode.BICUBIC,\n", + " random_scale_prob=0.,\n", + " random_scale_range=(0.85, 1.05),\n", + " random_aspect_prob=0.,\n", + " random_aspect_range=(0.9, 1.11)\n", + " ):\n", + " if isinstance(size, (list, tuple)):\n", + " self.size = tuple(size)\n", + " else:\n", + " self.size = (size, size)\n", + " self.interpolation = interpolation\n", + " self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest\n", + " self.random_scale_prob = random_scale_prob\n", + " self.random_scale_range = random_scale_range\n", + " self.random_aspect_prob = random_aspect_prob\n", + " self.random_aspect_range = random_aspect_range\n", + "\n", + " @staticmethod\n", + " def get_params(\n", + " img,\n", + " target_size,\n", + " longest,\n", + " random_scale_prob=0.,\n", + " random_scale_range=(0.85, 1.05),\n", + " random_aspect_prob=0.,\n", + " random_aspect_range=(0.9, 1.11)\n", + " ):\n", + " \"\"\"Get parameters\n", + " \"\"\"\n", + " source_size = img.size[::-1] # h, w\n", + " h, w = source_size\n", + " target_h, target_w = target_size\n", + " ratio_h = h / target_h\n", + " ratio_w = w / target_w\n", + " ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. 
- longest)\n", + " if random_scale_prob > 0 and random.random() < random_scale_prob:\n", + " ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])\n", + " ratio_factor = (ratio_factor, ratio_factor)\n", + " else:\n", + " ratio_factor = (1., 1.)\n", + " if random_aspect_prob > 0 and random.random() < random_aspect_prob:\n", + " aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1])\n", + " ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)\n", + " size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]\n", + " return size\n", + "\n", + " def __call__(self, img):\n", + " \"\"\"\n", + " Args:\n", + " img (PIL Image): Image to be cropped and resized.\n", + "\n", + " Returns:\n", + " PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size\n", + " \"\"\"\n", + " size = self.get_params(\n", + " img, self.size, self.longest,\n", + " self.random_scale_prob, self.random_scale_range,\n", + " self.random_aspect_prob, self.random_aspect_range\n", + " )\n", + " img = F.resize(img, size, self.interpolation)\n", + " return img\n", + "\n", + " def __repr__(self):\n", + " format_string = self.__class__.__name__ + '(size={0}'.format(self.size)\n", + " format_string += f', interpolation={self.interpolation})'\n", + " format_string += f', longest={self.longest:.3f})'\n", + " return format_string\n", + "\n", + "\n", + "def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:\n", + " \"\"\"Center crops and/or pads the given image.\n", + " If the image is torch Tensor, it is expected\n", + " to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.\n", + " If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.\n", + "\n", + " Args:\n", + " img (PIL Image or Tensor): Image to be cropped.\n", + " output_size (sequence or int): (height, width) of the crop box. 
If int or sequence with single int,\n", + " it is used for both directions.\n", + " fill (int, Tuple[int]): Padding color\n", + "\n", + " Returns:\n", + " PIL Image or Tensor: Cropped image.\n", + " \"\"\"\n", + " if isinstance(output_size, numbers.Number):\n", + " output_size = (int(output_size), int(output_size))\n", + " elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:\n", + " output_size = (output_size[0], output_size[0])\n", + "\n", + " _, image_height, image_width = F.get_dimensions(img)\n", + " crop_height, crop_width = output_size\n", + "\n", + " if crop_width > image_width or crop_height > image_height:\n", + " padding_ltrb = [\n", + " (crop_width - image_width) // 2 if crop_width > image_width else 0,\n", + " (crop_height - image_height) // 2 if crop_height > image_height else 0,\n", + " (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,\n", + " (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,\n", + " ]\n", + " img = F.pad(img, padding_ltrb, fill=fill)\n", + " _, image_height, image_width = F.get_dimensions(img)\n", + " if crop_width == image_width and crop_height == image_height:\n", + " return img\n", + "\n", + " crop_top = int(round((image_height - crop_height) / 2.0))\n", + " crop_left = int(round((image_width - crop_width) / 2.0))\n", + " return F.crop(img, crop_top, crop_left, crop_height, crop_width)\n", + "\n", + "\n", + "class CenterCropOrPad(torch.nn.Module):\n", + " \"\"\"Crops the given image at the center.\n", + " If the image is torch Tensor, it is expected\n", + " to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.\n", + " If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.\n", + "\n", + " Args:\n", + " size (sequence or int): Desired output size of the crop. If size is an\n", + " int instead of sequence like (h, w), a square crop (size, size) is\n", + " made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).\n", + " \"\"\"\n", + "\n", + " def __init__(self, size, fill=0):\n", + " super().__init__()\n", + " self.size = _setup_size(size, error_msg=\"Please provide only two dimensions (h, w) for size.\")\n", + " self.fill = fill\n", + "\n", + " def forward(self, img):\n", + " \"\"\"\n", + " Args:\n", + " img (PIL Image or Tensor): Image to be cropped.\n", + "\n", + " Returns:\n", + " PIL Image or Tensor: Cropped image.\n", + " \"\"\"\n", + " return center_crop_or_pad(img, self.size, fill=self.fill)\n", + "\n", + " def __repr__(self) -> str:\n", + " return f\"{self.__class__.__name__}(size={self.size})\"\n", + "\n", + "\n", + "def _convert_to_rgb(image):\n", + " return image.convert('RGB')\n", + "\n", + "\n", + "class color_jitter(object):\n", + " \"\"\"\n", + " Apply Color Jitter to the PIL image with a specified probability.\n", + " \"\"\"\n", + " def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., p=0.8):\n", + " assert 0. <= p <= 1.\n", + " self.p = p\n", + " self.transf = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)\n", + "\n", + " def __call__(self, img):\n", + " if random.random() < self.p:\n", + " return self.transf(img)\n", + " else:\n", + " return img\n", + "\n", + "\n", + "class gray_scale(object):\n", + " \"\"\"\n", + " Apply Gray Scale to the PIL image with a specified probability.\n", + " \"\"\"\n", + " def __init__(self, p=0.2):\n", + " assert 0. 
<= p <= 1.\n", + " self.p = p\n", + " self.transf = Grayscale(num_output_channels=3)\n", + "\n", + " def __call__(self, img):\n", + " if random.random() < self.p:\n", + " return self.transf(img)\n", + " else:\n", + " return img\n", + "\n", + "\n", + "def image_transform(\n", + " image_size: Union[int, Tuple[int, int]],\n", + " is_train: bool,\n", + " mean: Optional[Tuple[float, ...]] = None,\n", + " std: Optional[Tuple[float, ...]] = None,\n", + " resize_mode: Optional[str] = None,\n", + " interpolation: Optional[str] = None,\n", + " fill_color: int = 0,\n", + " aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,\n", + "):\n", + " mean = mean or OPENAI_DATASET_MEAN\n", + " if not isinstance(mean, (list, tuple)):\n", + " mean = (mean,) * 3\n", + "\n", + " std = std or OPENAI_DATASET_STD\n", + " if not isinstance(std, (list, tuple)):\n", + " std = (std,) * 3\n", + "\n", + " interpolation = interpolation or 'bicubic'\n", + " assert interpolation in ['bicubic', 'bilinear', 'random']\n", + " # NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for inference if set\n", + " interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC\n", + "\n", + " resize_mode = resize_mode or 'shortest'\n", + " assert resize_mode in ('shortest', 'longest', 'squash')\n", + "\n", + " if isinstance(aug_cfg, dict):\n", + " aug_cfg = AugmentationCfg(**aug_cfg)\n", + " else:\n", + " aug_cfg = aug_cfg or AugmentationCfg()\n", + "\n", + " normalize = Normalize(mean=mean, std=std)\n", + "\n", + " if is_train:\n", + " aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}\n", + " use_timm = aug_cfg_dict.pop('use_timm', False)\n", + " if use_timm:\n", + " from timm.data import create_transform # timm can still be optional\n", + " if isinstance(image_size, (tuple, list)):\n", + " assert len(image_size) >= 2\n", + " input_size = (3,) + image_size[-2:]\n", + " else:\n", + " input_size = (3, image_size, image_size)\n", + "\n", + " aug_cfg_dict.setdefault('color_jitter', None) # disable by default\n", + " # drop extra non-timm items\n", + " aug_cfg_dict.pop('color_jitter_prob', None)\n", + " aug_cfg_dict.pop('gray_scale_prob', None)\n", + "\n", + " train_transform = create_transform(\n", + " input_size=input_size,\n", + " is_training=True,\n", + " hflip=0.,\n", + " mean=mean,\n", + " std=std,\n", + " re_mode='pixel',\n", + " interpolation=interpolation,\n", + " **aug_cfg_dict,\n", + " )\n", + " else:\n", + " train_transform = [\n", + " RandomResizedCrop(\n", + " image_size,\n", + " scale=aug_cfg_dict.pop('scale'),\n", + " interpolation=InterpolationMode.BICUBIC,\n", + " ),\n", + " _convert_to_rgb,\n", + " ]\n", + " if aug_cfg.color_jitter_prob:\n", + " assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4\n", + " train_transform.extend([\n", + " color_jitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob)\n", + " ])\n", + " if aug_cfg.gray_scale_prob:\n", + " train_transform.extend([\n", + " gray_scale(aug_cfg.gray_scale_prob)\n", + " ])\n", + " train_transform.extend([\n", + " ToTensor(),\n", + " normalize,\n", + " ])\n", + " train_transform = Compose(train_transform)\n", + " if aug_cfg_dict:\n", + " warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')\n", + " return train_transform\n", + " else:\n", + " if resize_mode == 'longest':\n", + " transforms = [\n", + " ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),\n", + " 
CenterCropOrPad(image_size, fill=fill_color)\n", + " ]\n", + " elif resize_mode == 'squash':\n", + " if isinstance(image_size, int):\n", + " image_size = (image_size, image_size)\n", + " transforms = [\n", + " Resize(image_size, interpolation=interpolation_mode),\n", + " ]\n", + " else:\n", + " assert resize_mode == 'shortest'\n", + " if not isinstance(image_size, (tuple, list)):\n", + " image_size = (image_size, image_size)\n", + " if image_size[0] == image_size[1]:\n", + " # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)\n", + " transforms = [\n", + " Resize(image_size[0], interpolation=interpolation_mode)\n", + " ]\n", + " else:\n", + " # resize shortest edge to matching target dim for non-square target\n", + " transforms = [ResizeKeepRatio(image_size)]\n", + " transforms += [CenterCrop(image_size)]\n", + "\n", + " transforms.extend([\n", + " _convert_to_rgb,\n", + " ToTensor(),\n", + " normalize,\n", + " ])\n", + " return Compose(transforms)\n", + " \n", + " \n", + "def image_transform_v2(\n", + " cfg: PreprocessCfg,\n", + " is_train: bool,\n", + " aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,\n", + "):\n", + " return image_transform(\n", + " image_size=cfg.size,\n", + " is_train=is_train,\n", + " mean=cfg.mean,\n", + " std=cfg.std,\n", + " interpolation=cfg.interpolation,\n", + " resize_mode=cfg.resize_mode,\n", + " fill_color=cfg.fill_color,\n", + " aug_cfg=aug_cfg,\n", + " )\n", + "\n", + "@dataclass\n", + "class AugmentationCfg:\n", + " scale: Tuple[float, float] = (0.9, 1.0)\n", + " ratio: Optional[Tuple[float, float]] = None\n", + " color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None\n", + " re_prob: Optional[float] = None\n", + " re_count: Optional[int] = None\n", + " use_timm: bool = False\n", + "\n", + " # params for simclr_jitter_gray\n", + " color_jitter_prob: float = None\n", + " gray_scale_prob: float = None\n", + "\n", + "def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]):\n", + " module = getattr(model, 'visual', model)\n", + " module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat\n", + " module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat\n", + " module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict\n", + "\n", + "\n", + "@torch.no_grad()\n", + "def convert_mobile_clip_state_dict(model: CustomTextCLIP, state_dict, fastvit = True):\n", + "\n", + " def _convert_timm_img(state_dict):\n", + " if fastvit:\n", + " from timm.models.fastvit import checkpoint_filter_fn\n", + " else:\n", + " from timm.models.vision_transformer_hybrid import checkpoint_filter_fn\n", + " timm_state_dict = checkpoint_filter_fn(state_dict, model.visual.trunk)\n", + " timm_state_dict = {'visual.trunk.' 
+ k: v for k, v in timm_state_dict.items()}\n", + " return timm_state_dict\n", + "\n", + " def _convert_openclip_txt(state_dict, prefix='text_encoder.'):\n", + " text_dict = {}\n", + " for k, v in state_dict.items():\n", + " if not k.startswith(prefix):\n", + " continue\n", + " k = k.replace(prefix, '')\n", + " k = k.replace('projection_layer', 'text_projection')\n", + " k = k.replace('embedding_layer', 'token_embedding')\n", + " if k.startswith('positional_embedding.pos_embed.pos_embed'):\n", + " k = k.replace('positional_embedding.pos_embed.pos_embed', 'positional_embedding')\n", + " v = v.squeeze()\n", + " k = k.replace('final_layer_norm', 'ln_final')\n", + " k = k.replace('pre_norm_mha.0', 'ln_1')\n", + " k = k.replace('pre_norm_mha.1', 'attn')\n", + " k = k.replace('pre_norm_ffn.0', 'ln_2')\n", + " k = k.replace('pre_norm_ffn.1', 'mlp.c_fc')\n", + " k = k.replace('pre_norm_ffn.4', 'mlp.c_proj')\n", + " k = k.replace('qkv_proj.weight', 'in_proj_weight')\n", + " k = k.replace('qkv_proj.bias', 'in_proj_bias')\n", + " k = k.replace('transformer.', 'transformer.resblocks.')\n", + " text_dict['text.' + k] = v\n", + " return text_dict\n", + "\n", + " image_dict = _convert_timm_img(state_dict)\n", + " text_dict = _convert_openclip_txt(state_dict)\n", + " out_dict = {**image_dict, **text_dict}\n", + " out_dict['logit_scale'] = state_dict['logit_scale']\n", + " return out_dict\n", + "\n", + "\n", + "def convert_state_dict(model: Union[CustomTextCLIP, CLIP], state_dict):\n", + " if 'image_encoder.model.patch_embed.0.rbr_conv.0.conv.weight' in state_dict:\n", + " # Apple MobileCLIP s1 & s2 state_dicts (s0 and b not currently supported)\n", + " state_dict = convert_mobile_clip_state_dict(model, state_dict)\n", + " if 'image_encoder.model.patch_emb.0.block.conv.weight' in state_dict:\n", + " # convert b model\n", + " state_dict = convert_mobile_clip_state_dict(model, state_dict, fastvit=False)\n", + " return state_dict\n", + "\n", + "def load_state_dict(\n", + " checkpoint_path: str,\n", + " device='cpu',\n", + " weights_only=True,\n", + "):\n", + " # Check if safetensors or not and load weights accordingly\n", + " if str(checkpoint_path).endswith(\".safetensors\"):\n", + " from safetensors.torch import load_file\n", + " checkpoint = load_file(checkpoint_path, device=device)\n", + " else:\n", + " try:\n", + " checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)\n", + " except TypeError:\n", + " checkpoint = torch.load(checkpoint_path, map_location=device)\n", + "\n", + " if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:\n", + " state_dict = checkpoint['state_dict']\n", + " elif isinstance(checkpoint, torch.jit.ScriptModule):\n", + " state_dict = checkpoint.state_dict()\n", + " for key in [\"input_resolution\", \"context_length\", \"vocab_size\"]:\n", + " state_dict.pop(key, None)\n", + " else:\n", + " state_dict = checkpoint\n", + " if next(iter(state_dict.items()))[0].startswith('module'):\n", + " state_dict = {k[7:]: v for k, v in state_dict.items()}\n", + " return state_dict\n", + "\n", + "def load_checkpoint(\n", + " model: Union[CLIP, CustomTextCLIP],\n", + " checkpoint_path: str,\n", + " strict: bool = True,\n", + " weights_only: bool = True,\n", + " device='cpu',\n", + "):\n", + " if Path(checkpoint_path).suffix in ('.npz', '.npy'):\n", + " # Separate path loading numpy big_vision (SigLIP) weights\n", + " from open_clip.convert import load_big_vision_weights\n", + " load_big_vision_weights(model, checkpoint_path)\n", + " return {}\n", + "\n", 
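+    "    # Remaining path: read the raw state_dict, convert third-party layouts (e.g. MobileCLIP)\n",
+    "    # to open_clip naming, reconcile logit_scale / logit_bias shapes and resize positional\n",
+    "    # embeddings to this model's grid, then load (strict by default).\n",
+    "    # Illustrative call (checkpoint path is hypothetical):\n",
+    "    #   incompatible = load_checkpoint(model, 'checkpoints/ViT-B-32.pt', strict=False)\n",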
+ " state_dict = load_state_dict(checkpoint_path, device=device, weights_only=weights_only)\n", + "\n", + " # Detect & convert 3rd party state_dicts -> open_clip\n", + " state_dict = convert_state_dict(model, state_dict)\n", + "\n", + " # Detect old format and make compatible with new format\n", + " if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):\n", + " state_dict = convert_to_custom_text_state_dict(state_dict)\n", + "\n", + " # correct if logit_scale differs in being scaler vs 1d param\n", + " if 'logit_scale' in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim:\n", + " state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape)\n", + "\n", + " # correct if logit_bias differs in being scaler vs 1d param\n", + " if 'logit_bias' in state_dict and model.logit_bias.ndim != state_dict['logit_bias'].ndim:\n", + " state_dict['logit_bias'] = state_dict['logit_bias'].reshape(model.logit_bias.shape)\n", + "\n", + " # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712\n", + " if 'logit_bias' not in state_dict and model.logit_bias is not None:\n", + " state_dict[\"logit_bias\"] = torch.zeros_like(state_dict[\"logit_scale\"])\n", + "\n", + " # Certain text transformers no longer expect position_ids after transformers==4.31\n", + " position_id_key = 'text.transformer.embeddings.position_ids'\n", + " if position_id_key in state_dict and not hasattr(model, position_id_key):\n", + " del state_dict[position_id_key]\n", + "\n", + " resize_pos_embed(state_dict, model)\n", + " resize_text_pos_embed(state_dict, model)\n", + "\n", + " # Finally, load the massaged state_dict into model\n", + " incompatible_keys = model.load_state_dict(state_dict, strict=strict)\n", + " return incompatible_keys\n", + "\n", + "# /home/IITB/ai-at-ieor/23m1521/.conda/envs/openclip2/lib/python3.11/site-packages/open_clip/factory.py\n", + "HF_HUB_PREFIX = 'hf-hub:'\n", + "# _MODEL_CONFIG_PATHS = [Path(__file__).parent / f\"model_configs/\"]\n", + "_MODEL_CONFIG_PATHS = [Path(\"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/model_configs\")]\n", + "_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs\n", + "\n", + "import json\n", + "\n", + "def _get_hf_config(\n", + " model_id: str,\n", + " cache_dir: Optional[str] = None,\n", + "):\n", + " \"\"\" Fetch model config from HuggingFace Hub.\n", + " \"\"\"\n", + " config_path = download_pretrained_from_hf(\n", + " model_id,\n", + " filename='open_clip_config.json',\n", + " cache_dir=cache_dir,\n", + " )\n", + " with open(config_path, 'r', encoding='utf-8') as f:\n", + " config = json.load(f)\n", + " return config\n", + "\n", + "def get_model_config(model_name):\n", + " \"\"\" Fetch model config from builtin (local library) configs.\n", + " \"\"\"\n", + " if model_name in _MODEL_CONFIGS:\n", + " return copy.deepcopy(_MODEL_CONFIGS[model_name])\n", + " else:\n", + " return None\n", + "\n", + "def _natural_key(string_):\n", + " return [int(s) if s.isdigit() else s for s in re.split(r'(\\d+)', string_.lower())]\n", + "\n", + "\n", + "def _rescan_model_configs():\n", + " global _MODEL_CONFIGS\n", + "\n", + " config_ext = ('.json',)\n", + " config_files = []\n", + " for config_path in _MODEL_CONFIG_PATHS:\n", + " if config_path.is_file() and config_path.suffix in config_ext:\n", + " config_files.append(config_path)\n", + " elif config_path.is_dir():\n", + " for ext in config_ext:\n", + " 
config_files.extend(config_path.glob(f'*{ext}'))\n", + "\n", + " for cf in config_files:\n", + " with open(cf, 'r') as f:\n", + " model_cfg = json.load(f)\n", + " if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):\n", + " _MODEL_CONFIGS[cf.stem] = model_cfg\n", + "\n", + " _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}\n", + "\n", + "\n", + "_rescan_model_configs() # initial populate of model config registry\n", + "\n", + "def list_models():\n", + " \"\"\" enumerate available model architectures based on config files \"\"\"\n", + " return list(_MODEL_CONFIGS.keys())\n", + "\n", + "\n", + "def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs):\n", + " if past:\n", + " input_ids = input_ids[:, -1].unsqueeze(-1)\n", + "\n", + " attention_mask = kwargs.get(\"attention_mask\", None)\n", + " position_ids = kwargs.get(\"position_ids\", None)\n", + "\n", + " if attention_mask is not None and position_ids is None:\n", + " # create position_ids on the fly for batch generation\n", + " position_ids = attention_mask.long().cumsum(-1) - 1\n", + " position_ids.masked_fill_(attention_mask == 0, 1)\n", + " else:\n", + " position_ids = None\n", + " return {\n", + " \"text\": input_ids,\n", + " \"images\": image_inputs,\n", + " \"past_key_values\": past,\n", + " \"position_ids\": position_ids,\n", + " \"attention_mask\": attention_mask,\n", + " }\n", + "\n", + "@dataclass\n", + "class MultimodalCfg(CLIPTextCfg):\n", + " mlp_ratio: int = 4\n", + " dim_head: int = 64\n", + " heads: int = 8\n", + " n_queries: int = 256\n", + " attn_pooler_heads: int = 8\n", + "\n", + "try:\n", + " from transformers import (\n", + " BeamSearchScorer,\n", + " LogitsProcessorList,\n", + " TopPLogitsWarper,\n", + " TopKLogitsWarper,\n", + " RepetitionPenaltyLogitsProcessor,\n", + " MinLengthLogitsProcessor,\n", + " MaxLengthCriteria,\n", + " StopStringCriteria,\n", + " EosTokenCriteria,\n", + " StoppingCriteriaList\n", + " )\n", + "\n", + " GENERATION_TYPES = {\n", + " \"top_k\": TopKLogitsWarper,\n", + " \"top_p\": TopPLogitsWarper,\n", + " \"beam_search\": \"beam_search\"\n", + " }\n", + " _has_transformers = True\n", + "except ImportError as e:\n", + " GENERATION_TYPES = {\n", + " \"top_k\": None,\n", + " \"top_p\": None,\n", + " \"beam_search\": \"beam_search\"\n", + " }\n", + " _has_transformers = False\n", + "\n", + "def _token_to_tensor(token_id, device: str = \"cpu\") -> torch.Tensor:\n", + " if not isinstance(token_id, torch.Tensor):\n", + " if isinstance(token_id, int):\n", + " token_id = [token_id]\n", + " token_id = torch.tensor(token_id, device=device)\n", + " return token_id\n", + "\n", + "\n", + "def _build_text_decoder_tower(\n", + " embed_dim,\n", + " multimodal_cfg,\n", + " quick_gelu: bool = False,\n", + " cast_dtype: Optional[torch.dtype] = None,\n", + "):\n", + " multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg\n", + " act_layer = QuickGELU if quick_gelu else nn.GELU\n", + " norm_layer = (\n", + " LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm\n", + " )\n", + "\n", + " decoder = MultimodalTransformer(\n", + " context_length=multimodal_cfg.context_length,\n", + " width=multimodal_cfg.width,\n", + " heads=multimodal_cfg.heads,\n", + " layers=multimodal_cfg.layers,\n", + " ls_init_value=multimodal_cfg.ls_init_value,\n", + " output_dim=embed_dim,\n", + " act_layer=act_layer,\n", + " norm_layer=norm_layer,\n", + " )\n", + "\n", + 
" return decoder\n", + "\n", + "class CoCa(nn.Module):\n", + " def __init__(\n", + " self,\n", + " embed_dim,\n", + " multimodal_cfg: MultimodalCfg,\n", + " text_cfg: CLIPTextCfg,\n", + " vision_cfg: CLIPVisionCfg,\n", + " quick_gelu: bool = False,\n", + " init_logit_scale: float = np.log(1 / 0.07),\n", + " init_logit_bias: Optional[float] = None,\n", + " nonscalar_logit_scale: bool = False,\n", + " cast_dtype: Optional[torch.dtype] = None,\n", + " pad_id: int = 0,\n", + " ):\n", + " super().__init__()\n", + " multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg\n", + " text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg\n", + " vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg\n", + "\n", + " self.text = _build_text_tower(\n", + " embed_dim=embed_dim,\n", + " text_cfg=text_cfg,\n", + " quick_gelu=quick_gelu,\n", + " cast_dtype=cast_dtype,\n", + " )\n", + "\n", + " vocab_size = (\n", + " text_cfg.vocab_size # for hf models\n", + " if hasattr(text_cfg, \"hf_model_name\") and text_cfg.hf_model_name is not None\n", + " else text_cfg.vocab_size\n", + " )\n", + "\n", + " self.visual = _build_vision_tower(\n", + " embed_dim=embed_dim,\n", + " vision_cfg=vision_cfg,\n", + " quick_gelu=quick_gelu,\n", + " cast_dtype=cast_dtype,\n", + " )\n", + "\n", + " self.text_decoder = _build_text_decoder_tower(\n", + " vocab_size,\n", + " multimodal_cfg=multimodal_cfg,\n", + " quick_gelu=quick_gelu,\n", + " cast_dtype=cast_dtype,\n", + " )\n", + "\n", + " lshape = [1] if nonscalar_logit_scale else []\n", + " self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)\n", + " if init_logit_bias is not None:\n", + " self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)\n", + " else:\n", + " self.logit_bias = None\n", + " self.pad_id = pad_id\n", + "\n", + " self.context_length = multimodal_cfg.context_length\n", + "\n", + " @torch.jit.ignore\n", + " def set_grad_checkpointing(self, enable: bool = True):\n", + " self.visual.set_grad_checkpointing(enable)\n", + " self.text.set_grad_checkpointing(enable)\n", + " self.text_decoder.set_grad_checkpointing(enable)\n", + "\n", + " def _encode_image(self, images, normalize: bool = True):\n", + " image_latent, tokens_embs = self.visual(images)\n", + " image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent\n", + " return image_latent, tokens_embs\n", + "\n", + " def _encode_text(self, text, normalize: bool = True):\n", + " text_latent, token_emb = self.text(text)\n", + " text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent\n", + " return text_latent, token_emb\n", + "\n", + " def encode_image(self, images, normalize: bool = True):\n", + " image_latent, _ = self._encode_image(images, normalize=normalize)\n", + " return image_latent\n", + "\n", + " def encode_text(self, text, normalize: bool = True):\n", + " text_latent, _ = self._encode_text(text, normalize=normalize)\n", + " return text_latent\n", + "\n", + " def forward_intermediates(\n", + " self,\n", + " image: Optional[torch.Tensor] = None,\n", + " text: Optional[torch.Tensor] = None,\n", + " image_indices: Optional[Union[int, List[int]]] = None,\n", + " text_indices: Optional[Union[int, List[int]]] = None,\n", + " stop_early: bool = False,\n", + " normalize: bool = True,\n", + " normalize_intermediates: bool = False,\n", + " intermediates_only: bool = False,\n", + " image_output_fmt: str = 'NCHW',\n", + " 
image_output_extra_tokens: bool = False,\n", + " text_output_fmt: str = 'NLC',\n", + " text_output_extra_tokens: bool = False,\n", + " output_logits: bool = False,\n", + " output_logit_scale_bias: bool = False,\n", + " ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:\n", + " \"\"\" Forward features that returns intermediates.\n", + "\n", + " Args:\n", + " image: Input image tensor\n", + " text: Input text tensor\n", + " image_indices: For image tower, Take last n blocks if int, all if None, select matching indices if sequence\n", + " text_indices: Take last n blocks if int, all if None, select matching indices if sequence\n", + " stop_early: Stop iterating over blocks when last desired intermediate hit\n", + " normalize: L2 Normalize final image and text features (if present)\n", + " normalize_intermediates: Apply final encoder norm layer to all intermediates (if possible)\n", + " intermediates_only: Only return intermediate features, do not return final features\n", + " image_output_fmt: Shape of intermediate image feature outputs\n", + " image_output_extra_tokens: Return both prefix and spatial intermediate tokens\n", + " text_output_fmt: Shape of intermediate text feature outputs\n", + " text_output_extra_tokens: Return both prefix and spatial intermediate tokens\n", + " output_logits: Include logits in output\n", + " output_logit_scale_bias: Include the logit scale bias in the output\n", + " Returns:\n", + "\n", + " \"\"\"\n", + " output = {}\n", + " if intermediates_only:\n", + " # intermediates only disables final feature normalization, and include logits\n", + " normalize = False\n", + " output_logits = False\n", + " if output_logits:\n", + " assert False, 'FIXME, needs implementing'\n", + "\n", + " if image is not None:\n", + " image_output = self.visual.forward_intermediates(\n", + " image,\n", + " indices=image_indices,\n", + " stop_early=stop_early,\n", + " normalize_intermediates=normalize_intermediates,\n", + " intermediates_only=intermediates_only,\n", + " output_fmt=image_output_fmt,\n", + " output_extra_tokens=image_output_extra_tokens,\n", + " )\n", + " if normalize and \"image_features\" in image_output:\n", + " image_output[\"image_features\"] = F.normalize(image_output[\"image_features\"], dim=-1)\n", + " output.update(image_output)\n", + "\n", + " if text is not None:\n", + " text_output = self.text.forward_intermediates(\n", + " text,\n", + " indices=text_indices,\n", + " stop_early=stop_early,\n", + " normalize_intermediates=normalize_intermediates,\n", + " intermediates_only=intermediates_only,\n", + " output_fmt=text_output_fmt,\n", + " output_extra_tokens=text_output_extra_tokens,\n", + " )\n", + " if normalize and \"text_features\" in text_output:\n", + " text_output[\"text_features\"] = F.normalize(text_output[\"text_features\"], dim=-1)\n", + " output.update(text_output)\n", + "\n", + " # FIXME text decoder\n", + " logit_scale_exp = self.logit_scale.exp() if output_logits or output_logit_scale_bias else None\n", + " if output_logit_scale_bias:\n", + " output[\"logit_scale\"] = logit_scale_exp\n", + " if self.logit_bias is not None:\n", + " output['logit_bias'] = self.logit_bias\n", + "\n", + " return output\n", + "\n", + " def forward(\n", + " self,\n", + " image,\n", + " text: Optional[torch.Tensor] = None,\n", + " image_latent: Optional[torch.Tensor] = None,\n", + " image_embs: Optional[torch.Tensor] = None,\n", + " output_labels: bool = True,\n", + " ):\n", + " if image_latent is None or image_embs is None:\n", + " image_latent, image_embs = 
self._encode_image(image)\n", + "\n", + " if text is None:\n", + " return {\"image_features\": image_latent, \"image_embs\": image_embs}\n", + "\n", + " text_latent, token_embs = self._encode_text(text)\n", + "\n", + " # FIXME this isn't an ideal solution, would like to improve -RW\n", + " labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None\n", + " if output_labels:\n", + " # align text_embs and thus logits with labels for teacher-forcing caption loss\n", + " token_embs = token_embs[:, :-1]\n", + "\n", + " logits = self.text_decoder(image_embs, token_embs)\n", + " out_dict = {\n", + " \"image_features\": image_latent,\n", + " \"text_features\": text_latent,\n", + " \"logits\": logits,\n", + " \"logit_scale\": self.logit_scale.exp()\n", + " }\n", + " if labels is not None:\n", + " out_dict[\"labels\"] = labels\n", + " if self.logit_bias is not None:\n", + " out_dict[\"logit_bias\"] = self.logit_bias\n", + " return out_dict\n", + "\n", + " def generate(\n", + " self,\n", + " image,\n", + " text=None,\n", + " seq_len=30,\n", + " max_seq_len=77,\n", + " temperature=1.,\n", + " generation_type=\"beam_search\",\n", + " top_p=0.1, # keep tokens in the 1 - top_p quantile\n", + " top_k=1, # keeps the top_k most probable tokens\n", + " pad_token_id=None,\n", + " eos_token_id=None,\n", + " sot_token_id=None,\n", + " num_beams=6,\n", + " num_beam_groups=3,\n", + " min_seq_len=5,\n", + " stopping_criteria=None,\n", + " repetition_penalty=1.0,\n", + " fixed_output_length=False # if True output.shape == (batch_size, seq_len)\n", + " ):\n", + " # taking many ideas and components from HuggingFace GenerationMixin\n", + " # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation\n", + " assert _has_transformers, \"Please install transformers for generate functionality. 
`pip install transformers`.\"\n", + " assert seq_len > min_seq_len, \"seq_len must be larger than min_seq_len\"\n", + " device = image.device\n", + "\n", + " with torch.no_grad():\n", + " sot_token_id = _token_to_tensor(49406 if sot_token_id is None else sot_token_id, device=device)\n", + " eos_token_id = _token_to_tensor(49407 if eos_token_id is None else eos_token_id, device=device)\n", + " pad_token_id = self.pad_id if pad_token_id is None else pad_token_id\n", + " logit_processor = LogitsProcessorList(\n", + " [\n", + " MinLengthLogitsProcessor(min_seq_len, eos_token_id),\n", + " RepetitionPenaltyLogitsProcessor(repetition_penalty),\n", + " ]\n", + " )\n", + "\n", + " if stopping_criteria is None:\n", + " stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]\n", + " stopping_criteria = StoppingCriteriaList(stopping_criteria)\n", + "\n", + " if generation_type == \"beam_search\":\n", + " output = self._generate_beamsearch(\n", + " image_inputs=image,\n", + " pad_token_id=pad_token_id,\n", + " eos_token_id=eos_token_id,\n", + " sot_token_id=sot_token_id,\n", + " num_beams=num_beams,\n", + " num_beam_groups=num_beam_groups,\n", + " min_seq_len=min_seq_len,\n", + " stopping_criteria=stopping_criteria,\n", + " logit_processor=logit_processor,\n", + " )\n", + " if fixed_output_length and output.shape[1] < seq_len:\n", + " pad_len = seq_len - output.shape[1]\n", + " return torch.cat((\n", + " output,\n", + " torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id\n", + " ),\n", + " dim=1\n", + " )\n", + " return output\n", + "\n", + " elif generation_type == \"top_p\":\n", + " logit_warper = GENERATION_TYPES[generation_type](top_p)\n", + " elif generation_type == \"top_k\":\n", + " logit_warper = GENERATION_TYPES[generation_type](top_k)\n", + " else:\n", + " raise ValueError(\n", + " f\"generation_type has to be one of \"\n", + " f\"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}.\"\n", + " )\n", + "\n", + " image_latent, image_embs = self._encode_image(image)\n", + "\n", + " if text is None:\n", + " text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id\n", + "\n", + " was_training = self.training\n", + " num_dims = len(text.shape)\n", + "\n", + " if num_dims == 1:\n", + " text = text[None, :]\n", + "\n", + " self.eval()\n", + " out = text\n", + "\n", + " while True:\n", + " x = out[:, -max_seq_len:]\n", + " cur_len = x.shape[1]\n", + " logits = self(\n", + " image,\n", + " x,\n", + " image_latent=image_latent,\n", + " image_embs=image_embs,\n", + " output_labels=False,\n", + " )[\"logits\"][:, -1]\n", + " mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)\n", + " sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id\n", + "\n", + " if mask.all():\n", + " if not fixed_output_length:\n", + " break\n", + " else:\n", + " logits = logits[~mask, :]\n", + " filtered_logits = logit_processor(x[~mask, :], logits)\n", + " filtered_logits = logit_warper(x[~mask, :], filtered_logits)\n", + " probs = F.softmax(filtered_logits / temperature, dim=-1)\n", + "\n", + " if (cur_len + 1 == seq_len):\n", + " sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id\n", + " else:\n", + " sample[~mask, :] = torch.multinomial(probs, 1)\n", + "\n", + " out = torch.cat((out, sample), dim=-1)\n", + "\n", + " cur_len += 1\n", + "\n", + " if all(stopping_criteria(out, None)):\n", + " break\n", + "\n", + " if num_dims == 1:\n", + " out = out.squeeze(0)\n", + 
"\n", + " self.train(was_training)\n", + " return out\n", + "\n", + " def _generate_beamsearch(\n", + " self,\n", + " image_inputs,\n", + " pad_token_id=None,\n", + " eos_token_id=None,\n", + " sot_token_id=None,\n", + " num_beams=6,\n", + " num_beam_groups=3,\n", + " min_seq_len=5,\n", + " stopping_criteria=None,\n", + " logit_processor=None,\n", + " logit_warper=None,\n", + " ):\n", + " device = image_inputs.device\n", + " batch_size = image_inputs.shape[0]\n", + " image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0)\n", + " image_latent, image_embs = self._encode_image(image_inputs)\n", + "\n", + " input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)\n", + " input_ids = input_ids * sot_token_id\n", + " beam_scorer = BeamSearchScorer(\n", + " batch_size=batch_size,\n", + " num_beams=num_beams,\n", + " device=device,\n", + " num_beam_groups=num_beam_groups,\n", + " )\n", + " # instantiate logits processors\n", + " logits_processor = (\n", + " LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])\n", + " if logit_processor is None\n", + " else logit_processor\n", + " )\n", + "\n", + " num_beams = beam_scorer.num_beams\n", + " num_beam_groups = beam_scorer.num_beam_groups\n", + " num_sub_beams = num_beams // num_beam_groups\n", + " batch_size = len(beam_scorer._beam_hyps) // num_beam_groups\n", + " batch_beam_size, cur_len = input_ids.shape\n", + " beam_indices = None\n", + "\n", + " if num_beams * batch_size != batch_beam_size:\n", + " raise ValueError(\n", + " f\"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}.\"\n", + " )\n", + "\n", + " beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)\n", + " # initialise score of first beam of each group with 0 and the rest with 1e-9. 
This ensures that the beams in\n", + " # the same group don't produce same tokens everytime.\n", + " beam_scores[:, ::num_sub_beams] = 0\n", + " beam_scores = beam_scores.view((batch_size * num_beams,))\n", + "\n", + " while True:\n", + "\n", + " # predicted tokens in cur_len step\n", + " current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)\n", + "\n", + " # indices which will form the beams in the next time step\n", + " reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)\n", + "\n", + " # do one decoder step on all beams of all sentences in batch\n", + " model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs)\n", + " outputs = self(\n", + " model_inputs['images'],\n", + " model_inputs['text'],\n", + " image_latent=image_latent,\n", + " image_embs=image_embs,\n", + " output_labels=False,\n", + " )\n", + "\n", + " for beam_group_idx in range(num_beam_groups):\n", + " group_start_idx = beam_group_idx * num_sub_beams\n", + " group_end_idx = min(group_start_idx + num_sub_beams, num_beams)\n", + " group_size = group_end_idx - group_start_idx\n", + "\n", + " # indices of beams of current group among all sentences in batch\n", + " batch_group_indices = []\n", + "\n", + " for batch_idx in range(batch_size):\n", + " batch_group_indices.extend(\n", + " [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]\n", + " )\n", + " group_input_ids = input_ids[batch_group_indices]\n", + "\n", + " # select outputs of beams of currentg group only\n", + " next_token_logits = outputs['logits'][batch_group_indices, -1, :]\n", + " vocab_size = next_token_logits.shape[-1]\n", + "\n", + " next_token_scores_processed = logits_processor(\n", + " group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx\n", + " )\n", + " next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)\n", + " next_token_scores = next_token_scores.expand_as(next_token_scores_processed)\n", + "\n", + " # reshape for beam search\n", + " next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)\n", + "\n", + " next_token_scores, next_tokens = torch.topk(\n", + " next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True\n", + " )\n", + "\n", + " next_indices = torch.div(next_tokens, vocab_size, rounding_mode=\"floor\")\n", + " next_tokens = next_tokens % vocab_size\n", + "\n", + " # stateless\n", + " process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None\n", + " beam_outputs = beam_scorer.process(\n", + " group_input_ids,\n", + " next_token_scores,\n", + " next_tokens,\n", + " next_indices,\n", + " pad_token_id=pad_token_id,\n", + " eos_token_id=eos_token_id,\n", + " beam_indices=process_beam_indices,\n", + " group_index=beam_group_idx,\n", + " )\n", + " beam_scores[batch_group_indices] = beam_outputs[\"next_beam_scores\"]\n", + " beam_next_tokens = beam_outputs[\"next_beam_tokens\"]\n", + " beam_idx = beam_outputs[\"next_beam_indices\"]\n", + "\n", + " input_ids[batch_group_indices] = group_input_ids[beam_idx]\n", + " group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)\n", + " current_tokens[batch_group_indices] = group_input_ids[:, -1]\n", + "\n", + " # (beam_idx // group_size) -> batch_idx\n", + " # (beam_idx % group_size) -> offset of idx inside the group\n", + " reordering_indices[batch_group_indices] = (\n", + " num_beams * 
torch.div(beam_idx, group_size, rounding_mode=\"floor\") + group_start_idx + (beam_idx % group_size)\n", + " )\n", + "\n", + " input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)\n", + "\n", + " # increase cur_len\n", + " cur_len = cur_len + 1\n", + " if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):\n", + " break\n", + "\n", + " final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None\n", + " sequence_outputs = beam_scorer.finalize(\n", + " input_ids,\n", + " beam_scores,\n", + " next_tokens,\n", + " next_indices,\n", + " pad_token_id=pad_token_id,\n", + " eos_token_id=eos_token_id,\n", + " max_length=stopping_criteria.max_length,\n", + " beam_indices=final_beam_indices,\n", + " )\n", + " return sequence_outputs['sequences']\n", + "\n", + "\n", + "def create_model(\n", + " model_name: str,\n", + " pretrained: Optional[str] = None,\n", + " precision: str = 'fp32',\n", + " device: Union[str, torch.device] = 'cpu',\n", + " jit: bool = False,\n", + " force_quick_gelu: bool = False,\n", + " force_custom_text: bool = False,\n", + " force_patch_dropout: Optional[float] = None,\n", + " force_image_size: Optional[Union[int, Tuple[int, int]]] = None,\n", + " force_preprocess_cfg: Optional[Dict[str, Any]] = None,\n", + " pretrained_image: bool = False,\n", + " pretrained_hf: bool = True,\n", + " cache_dir: Optional[str] = None,\n", + " output_dict: Optional[bool] = None,\n", + " require_pretrained: bool = False,\n", + " load_weights_only: bool = True,\n", + " **model_kwargs,\n", + "):\n", + " \"\"\"Creates and configures a contrastive vision-language model.\n", + "\n", + " Args:\n", + " model_name: Name of the model architecture to create. Can be a local model name\n", + " or a Hugging Face model ID prefixed with 'hf-hub:'.\n", + " pretrained: Tag/path for pretrained model weights. Can be:\n", + " - A pretrained tag name (e.g., 'openai')\n", + " - A path to local weights\n", + " - None to initialize with random weights\n", + " precision: Model precision/AMP configuration. 
Options:\n", + " - 'fp32': 32-bit floating point\n", + " - 'fp16'/'bf16': Mixed precision with FP32 for certain layers\n", + " - 'pure_fp16'/'pure_bf16': Pure 16-bit precision\n", + " device: Device to load the model on ('cpu', 'cuda', or torch.device object)\n", + " jit: If True, JIT compile the model\n", + " force_quick_gelu: Force use of QuickGELU activation\n", + " force_custom_text: Force use of custom text encoder\n", + " force_patch_dropout: Override default patch dropout value\n", + " force_image_size: Override default image size for vision encoder\n", + " force_preprocess_cfg: Override default preprocessing configuration\n", + " pretrained_image: Load pretrained weights for timm vision models\n", + " pretrained_hf: Load pretrained weights for HF text models when not loading CLIP weights\n", + " cache_dir: Override default cache directory for downloaded model files\n", + " output_dict: If True and model supports it, return dictionary of features\n", + " require_pretrained: Raise error if pretrained weights cannot be loaded\n", + " load_weights_only: Only deserialize model weights and unpickling torch checkpoints (for safety)\n", + " **model_kwargs: Additional keyword arguments passed to model constructor\n", + "\n", + " Returns:\n", + " Created and configured model instance\n", + "\n", + " Raises:\n", + " RuntimeError: If model config is not found or required pretrained weights\n", + " cannot be loaded\n", + "\n", + " Examples:\n", + " # Create basic CLIP model\n", + " model = create_model('ViT-B/32')\n", + "\n", + " # Create CLIP model with mixed precision on GPU\n", + " model = create_model('ViT-B/32', precision='fp16', device='cuda')\n", + "\n", + " # Load pretrained OpenAI weights\n", + " model = create_model('ViT-B/32', pretrained='openai')\n", + "\n", + " # Load Hugging Face model\n", + " model = create_model('hf-hub:organization/model-name')\n", + " \"\"\"\n", + "\n", + " force_preprocess_cfg = force_preprocess_cfg or {}\n", + " preprocess_cfg = asdict(PreprocessCfg())\n", + " has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)\n", + " if has_hf_hub_prefix:\n", + " model_id = model_name[len(HF_HUB_PREFIX):]\n", + " checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)\n", + " config = _get_hf_config(model_id, cache_dir=cache_dir)\n", + " preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])\n", + " model_cfg = config['model_cfg']\n", + " pretrained_hf = False # override, no need to load original HF text weights\n", + " else:\n", + " model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names\n", + " checkpoint_path = None\n", + " model_cfg = None\n", + "\n", + " if isinstance(device, str):\n", + " device = torch.device(device)\n", + "\n", + " model_cfg = model_cfg or get_model_config(model_name)\n", + " if model_cfg is not None:\n", + " logging.info(f'Loaded {model_name} model config.')\n", + " else:\n", + " logging.error(f'Model config for {model_name} not found; available models {list_models()}.')\n", + " raise RuntimeError(f'Model config for {model_name} not found.')\n", + "\n", + " if force_quick_gelu:\n", + " # override for use of QuickGELU on non-OpenAI transformer models\n", + " model_cfg[\"quick_gelu\"] = True\n", + "\n", + " if force_patch_dropout is not None:\n", + " # override the default patch dropout value\n", + " model_cfg[\"vision_cfg\"][\"patch_dropout\"] = force_patch_dropout\n", + "\n", + " if force_image_size is not None:\n", + " # override model config's image size\n", 
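+    "        # (the vision tower is then built at this size; positional embeddings in a\n",
+    "        # pretrained checkpoint are resized to match via resize_pos_embed at load time)\n",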
+ " model_cfg[\"vision_cfg\"][\"image_size\"] = force_image_size\n", + "\n", + " is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})\n", + " if pretrained_image:\n", + " if is_timm_model:\n", + " # pretrained weight loading for timm models set via vision_cfg\n", + " model_cfg['vision_cfg']['timm_model_pretrained'] = True\n", + " else:\n", + " assert False, 'pretrained image towers currently only supported for timm models'\n", + "\n", + " # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes\n", + " cast_dtype = get_cast_dtype(precision)\n", + " is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})\n", + " if is_hf_model:\n", + " # load pretrained weights for HF text model IFF no CLIP weights being loaded\n", + " model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained\n", + " custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model\n", + "\n", + " model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg)\n", + " if custom_text:\n", + " if \"multimodal_cfg\" in model_cfg:\n", + " model = CoCa(**model_cfg, cast_dtype=cast_dtype)\n", + " else:\n", + " model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)\n", + " else:\n", + " model = CLIP(**model_cfg, cast_dtype=cast_dtype)\n", + "\n", + " if precision in (\"fp16\", \"bf16\"):\n", + " dtype = torch.float16 if 'fp16' in precision else torch.bfloat16\n", + " # manual mixed precision that matches original OpenAI behaviour\n", + " if is_timm_model:\n", + " # FIXME this is a bit janky, create timm based model in low-precision and\n", + " # then cast only LayerNormFp32 instances back to float32 so they don't break.\n", + " # Why? The convert_weights_to_lp fn only works with native models.\n", + " model.to(device=device, dtype=dtype)\n", + " # from .transformer import LayerNormFp32\n", + "\n", + " def _convert_ln(m):\n", + " if isinstance(m, LayerNormFp32):\n", + " m.weight.data = m.weight.data.to(torch.float32)\n", + " m.bias.data = m.bias.data.to(torch.float32)\n", + " model.apply(_convert_ln)\n", + " else:\n", + " model.to(device=device)\n", + " convert_weights_to_lp(model, dtype=dtype)\n", + " elif precision in (\"pure_fp16\", \"pure_bf16\"):\n", + " dtype = torch.float16 if 'fp16' in precision else torch.bfloat16\n", + " model.to(device=device, dtype=dtype)\n", + " else:\n", + " model.to(device=device)\n", + "\n", + " pretrained_loaded = False\n", + " if pretrained:\n", + " checkpoint_path = ''\n", + " pretrained_cfg = get_pretrained_cfg(model_name, pretrained)\n", + " if pretrained_cfg:\n", + " checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)\n", + " preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)\n", + " pretrained_quick_gelu = pretrained_cfg.get('quick_gelu', False)\n", + " model_quick_gelu = model_cfg.get('quick_gelu', False)\n", + " if pretrained_quick_gelu and not model_quick_gelu:\n", + " warnings.warn(\n", + " f'These pretrained weights were trained with QuickGELU activation but the model config does '\n", + " f'not have that enabled. 
Consider using a model config with a \"-quickgelu\" suffix or enable with a flag.')\n", + " elif not pretrained_quick_gelu and model_quick_gelu:\n", + " warnings.warn(\n", + " f'The pretrained weights were not trained with QuickGELU but this activation is enabled in the '\n", + " f'model config, consider using a model config without QuickGELU or disable override flags.')\n", + " elif os.path.exists(pretrained):\n", + " checkpoint_path = pretrained\n", + "\n", + " if checkpoint_path:\n", + " logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')\n", + " load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)\n", + " else:\n", + " error_str = (\n", + " f'Pretrained weights ({pretrained}) not found for model {model_name}.'\n", + " f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')\n", + " logging.warning(error_str)\n", + " raise RuntimeError(error_str)\n", + " pretrained_loaded = True\n", + " elif has_hf_hub_prefix:\n", + " logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')\n", + " load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)\n", + " pretrained_loaded = True\n", + "\n", + " if require_pretrained and not pretrained_loaded:\n", + " # callers of create_model_from_pretrained always expect pretrained weights\n", + " raise RuntimeError(\n", + " f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')\n", + "\n", + " if output_dict and hasattr(model, \"output_dict\"):\n", + " model.output_dict = True\n", + "\n", + " if jit:\n", + " model = torch.jit.script(model)\n", + "\n", + " # set image preprocessing configuration in model attributes for convenience\n", + " if getattr(model.visual, 'image_size', None) is not None:\n", + " # use image_size set on model creation (via config or force_image_size arg)\n", + " force_preprocess_cfg['size'] = model.visual.image_size\n", + " set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg))\n", + "\n", + " return model\n", + "\n", + "def create_model_and_transforms(\n", + " model_name: str,\n", + " pretrained: Optional[str] = None,\n", + " precision: str = 'fp32',\n", + " device: Union[str, torch.device] = 'cpu',\n", + " jit: bool = False,\n", + " force_quick_gelu: bool = False,\n", + " force_custom_text: bool = False,\n", + " force_patch_dropout: Optional[float] = None,\n", + " force_image_size: Optional[Union[int, Tuple[int, int]]] = None,\n", + " image_mean: Optional[Tuple[float, ...]] = None,\n", + " image_std: Optional[Tuple[float, ...]] = None,\n", + " image_interpolation: Optional[str] = None,\n", + " image_resize_mode: Optional[str] = None, # only effective for inference\n", + " aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,\n", + " pretrained_image: bool = False,\n", + " pretrained_hf: bool = True,\n", + " cache_dir: Optional[str] = None,\n", + " output_dict: Optional[bool] = None,\n", + " load_weights_only: bool = True,\n", + " **model_kwargs,\n", + "):\n", + " force_preprocess_cfg = merge_preprocess_kwargs(\n", + " {},\n", + " mean=image_mean,\n", + " std=image_std,\n", + " interpolation=image_interpolation,\n", + " resize_mode=image_resize_mode,\n", + " )\n", + "\n", + " model = create_model(\n", + " model_name,\n", + " pretrained,\n", + " precision=precision,\n", + " device=device,\n", + " jit=jit,\n", + " force_quick_gelu=force_quick_gelu,\n", + " force_custom_text=force_custom_text,\n", + " force_patch_dropout=force_patch_dropout,\n", + 
" force_image_size=force_image_size,\n", + " force_preprocess_cfg=force_preprocess_cfg,\n", + " pretrained_image=pretrained_image,\n", + " pretrained_hf=pretrained_hf,\n", + " cache_dir=cache_dir,\n", + " output_dict=output_dict,\n", + " load_weights_only=load_weights_only,\n", + " **model_kwargs,\n", + " )\n", + "\n", + " pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)\n", + "\n", + " preprocess_train = image_transform_v2(\n", + " pp_cfg,\n", + " is_train=True,\n", + " aug_cfg=aug_cfg,\n", + " )\n", + " preprocess_val = image_transform_v2(\n", + " pp_cfg,\n", + " is_train=False,\n", + " )\n", + "\n", + " return model, preprocess_train, preprocess_val\n", + "\n", + "\n", + "\n", + "open_clip_model, open_clip_imgaug, open_clip_preprocess = create_model_and_transforms(\n", + " model_name='ViT-H-14', pretrained='laion2b_s32b_b79k', device=device\n", + ")\n", + "# print(\"ashish 1\")\n", + "# exit()\n", + "\n", + "# ==================================================================\n", + "# C S I P - M O D U L E\n", + "# ==================================================================\n", + "class CSIP(nn.Module):\n", + " def __init__(self, image_encoder, audio_encoder, \n", + " dim_img=None, dim_audio=1024, dim_emb=1024):\n", + " super(CSIP, self).__init__()\n", + " \n", + " self.image_encoder = image_encoder # CLIPVisionModel\n", + " self.audio_encoder = audio_encoder # CLAP_audio_encoder\n", + "\n", + " for param in self.image_encoder.parameters():\n", + " param.requires_grad = False\n", + " \n", + " # self.image_proj = nn.Linear(dim_img, dim_emb)\n", + " self.audio_proj = nn.Linear(dim_audio, dim_emb)\n", + "\n", + " # Learnable temperature parameter\n", + " self.log_temp = nn.Parameter(torch.tensor(0.07).log())\n", + "\n", + " def forward(self, images, audios):\n", + " \n", + " # image_features = self.image_encoder(images) # shape: [n, dim_img]\n", + " image_features = images # shape: [n, dim_img]\n", + " audio_features = self.audio_encoder(audios)[0] # shape: [n, dim_audio]\n", + " \n", + " # Step 2: Project and normalize\n", + " image_embeds = F.normalize(image_features, dim=1) # [n, dim_emb]\n", + " audio_embeds = F.normalize(self.audio_proj(audio_features), dim=1) # [n, dim_emb]\n", + "\n", + " # Step 3: Cosine similarity with temperature\n", + " logits = torch.matmul(image_embeds, audio_embeds.T) * self.log_temp.exp() # [n, n]\n", + " probs = logits.softmax(dim=1)\n", + "\n", + " # Step 4: Symmetric cross-entropy loss\n", + " labels = torch.arange(len(images), device=images.device)\n", + " loss_i = F.cross_entropy(logits, labels)\n", + " loss_a = F.cross_entropy(logits.T, labels)\n", + " loss = (loss_i + loss_a) / 2\n", + " \n", + " # Step 5: Similarity metric (average cosine similarity on matched pairs)\n", + " similarity_scores = (image_embeds * audio_embeds).sum(dim=1) # Cosine similarity of matching pairs\n", + " avg_similarity = similarity_scores.mean()\n", + "\n", + " return loss, loss_i, loss_a, logits, probs, avg_similarity\n", + "\n", + "\n", + "# ==================================================================\n", + "# I M A G E - A U D I O - D A T A S E T\n", + "# ==================================================================\n", + "class VaaniImageAudioDataset(torch.utils.data.Dataset):\n", + " def __init__(self, df, image_features_savedir, audio_tensors_savedir):\n", + " self.image_paths = df.image_path.tolist()\n", + " self.audio_paths = df.audio_path.tolist()\n", + " self.image_features_savedir = image_features_savedir\n", + " self.audio_tensors_savedir = 
audio_tensors_savedir\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return len(self.audio_paths)\n",
+    "\n",
+    "    def __getitem__(self, idx):\n",
+    "        return {\n",
+    "            'image_path': self.image_paths[idx],\n",
+    "            'image_feature': torch.load(os.path.join(\n",
+    "                self.image_features_savedir, \n",
+    "                f\"{os.path.basename(self.image_paths[idx])}.pt\"))['image_features'],\n",
+    "            'audio_path': self.audio_paths[idx],\n",
+    "            'audio_tensor': torch.load(os.path.join(\n",
+    "                self.audio_tensors_savedir, \n",
+    "                f\"{os.path.basename(self.audio_paths[idx])}.pt\"))['audio_tensor']\n",
+    "        }\n",
+    "    \n",
+    "\n",
+    "train_df = pd.read_csv(\"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TRAIN3.csv\")\n",
+    "test_df = pd.read_csv(\"/home/IITB/ai-at-ieor/23m1521/ashish/MTP/Vaani/Img_Audio_Alignment/available_img_audios_TEST2.csv\")\n",
+    "image_features_savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Image_features/'\n",
+    "audio_tensors_savedir = '/scratch/IITB/ai-at-ieor/23m1521/datasets/Vaani/Hindi_Audio_tensors/'\n",
+    "train_dataset = VaaniImageAudioDataset(train_df, image_features_savedir, audio_tensors_savedir)\n",
+    "test_dataset = VaaniImageAudioDataset(test_df, image_features_savedir, audio_tensors_savedir)\n",
+    "\n",
+    "print('Train Dataset:', len(train_dataset))\n",
+    "print('Test Dataset:', len(test_dataset))\n",
+    "\n",
+    "\n",
+    "BATCH_SIZE = 128\n",
+    "train_dataloader = torch.utils.data.DataLoader(\n",
+    "    train_dataset,\n",
+    "    batch_size=BATCH_SIZE, \n",
+    "    shuffle=True, \n",
+    "    num_workers=48,\n",
+    "    pin_memory=True,\n",
+    "    drop_last=False,\n",
+    "    persistent_workers=True\n",
+    ")\n",
+    "\n",
+    "test_dataloader = torch.utils.data.DataLoader(\n",
+    "    test_dataset,\n",
+    "    batch_size=BATCH_SIZE, \n",
+    "    shuffle=False, \n",
+    "    num_workers=48,\n",
+    "    pin_memory=True,\n",
+    "    drop_last=False,\n",
+    "    persistent_workers=True\n",
+    ")\n",
+    "\n",
+    "batch = next(iter(train_dataloader))\n",
+    "image_features_batch = batch['image_feature'].to(device=device)\n",
+    "audio_tensor_batch = batch['audio_tensor'].to(device=device)\n",
+    "image_paths_batch = batch['image_path']\n",
+    "audio_paths_batch = batch['audio_path']\n",
+    "print(\"Image batch shape:\", image_features_batch.shape) # [BATCH_SIZE, 1024] precomputed CLIP image features\n",
+    "print(\"Audio batch shape:\", audio_tensor_batch.shape) # [BATCH_SIZE, 308700] raw audio samples"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "4bbedbd2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "csip_model = CSIP(open_clip_model.visual, peft_clap_audio_encoder).to(device)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "ef5099c8",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "=========================================================================================================================================================================================\n",
+       "Layer (type (var_name)) Trainable Param % Input Shape Output Shape Param #\n",
+       "=========================================================================================================================================================================================\n",
+       "CSIP (CSIP) Partial 94.70% [128, 1024] -- 632,076,801\n",
+       "├─LoraModel (audio_encoder) Partial -- [128, 308700] [128, 1024] --\n",
+       "│ └─AudioEncoder (model) Partial -- -- -- --\n",
+       "│ │ └─HTSATWrapper (base) Partial -- [128, 308700] [128, 768] --\n",
+       "│ │ │ └─HTSAT_Swin_Transformer (htsat) Partial 4.86% [128, 308700] 
[128, 768] 32,457,431\n", + "│ │ └─Projection (projection) True -- [128, 768] [128, 1024] --\n", + "│ │ │ └─Linear (linear1) True 0.12% [128, 768] [128, 1024] 815,104\n", + "│ │ │ └─Linear (linear2) True 0.16% [128, 1024] [128, 1024] 1,081,344\n", + "│ │ │ └─Dropout (drop) -- -- [128, 1024] [128, 1024] --\n", + "│ │ │ └─LayerNorm (layer_norm) True 0.00% [128, 1024] [128, 1024] 2,048\n", + "├─Linear (audio_proj) True 0.16% [128, 1024] [128, 1024] 1,049,600\n", + "=========================================================================================================================================================================================\n", + "Total params: 667,482,328\n", + "Trainable params: 4,175,127\n", + "Non-trainable params: 663,307,201\n", + "Total mult-adds (Units.GIGABYTES): 148.39\n", + "=========================================================================================================================================================================================\n", + "Input size (MB): 158.58\n", + "Forward/backward pass size (MB): 42578.49\n", + "Params size (MB): 140.34\n", + "Estimated Total Size (MB): 42877.41\n", + "=========================================================================================================================================================================================" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from torchinfo import summary\n", + "import subprocess\n", + "\n", + "for param in csip_model.audio_encoder.model.projection.parameters():\n", + " param.requires_grad = True\n", + " \n", + "summary(model=csip_model,\n", + " input_data=((image_features_batch.to(device)), (audio_tensor_batch.to(device))),\n", + " # input_size = (1, 3, config.IMAGE_SIZE, config.IMAGE_SIZE),\n", + " dtypes=[torch.long],\n", + " col_names = [\"trainable\", \"params_percent\", \"input_size\", \"output_size\", \"num_params\"],\n", + " col_width=20,\n", + " row_settings=[\"var_names\"],\n", + " depth = 4,\n", + " # verbose=2,\n", + " # device=device\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "clap", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}