""" Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause Based on timm code base https://github.com/rwightman/pytorch-image-models/tree/master/timm """ """ Copyright (c) 2023, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ """ Copyright (c) 2023, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ # Copyright (c) 2024 Black Forest Labs. # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 # # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20. # # Original file was released under Apache-2.0, with the full license text # available at https://github.com/black-forest-labs/flux/blob/main/LICENSE. # # This modified file is released under the same license. """ * Copyright (c) 2023, salesforce.com, inc. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause * By Junnan Li * Based on huggingface code base * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert """ from dataclasses import dataclass import torch from einops import rearrange from torch import Tensor, nn from safetensors.torch import load_file as load_sft import torch.nn as nn import torch # import math # from torchvision import transforms import os # from timm.models import create_model from typing import Any, Dict, List, Optional, Union from transformers import LlamaTokenizer # from torchvision.transforms.functional import pil_to_tensor # import torch from PIL import Image from torchvision import transforms import torch.utils.checkpoint as checkpoint DIFFUSION_NAME = 'stabilityai/stable-diffusion-2-1-unclip' #from diffusers import StableUnCLIPImg2ImgPipeline import logging import torch import torch.distributed as dist import torch.nn as nn from torch.cuda.amp import autocast as autocast from torch.nn import functional as F import numpy as np from functools import partial from einops import rearrange import contextlib import logging import os import time import datetime import torch import torch.nn as nn import torch.distributed as dist import torch.nn.functional as F from timm.models.layers import drop_path, to_2tuple, trunc_normal_ from transformers import BertTokenizer import math import torch import torch.nn as nn import torch.nn.functional as F from functools import partial from timm.models.vision_transformer import _cfg, PatchEmbed from timm.models.registry import register_model from timm.models.layers import trunc_normal_, DropPath from timm.models.helpers import named_apply, adapt_input_conv import math import os import warnings from dataclasses import dataclass from typing import Optional, Tuple, Dict, Any import torch from torch import Tensor, device, dtype, nn import torch.utils.checkpoint from torch.nn import CrossEntropyLoss import torch.nn.functional as F import numpy as np from transformers.activations import ACT2FN from transformers.file_utils import ( ModelOutput, ) from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, MaskedLMOutput, MultipleChoiceModelOutput, NextSentencePredictorOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput, ) from transformers.modeling_utils import ( PreTrainedModel, apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer, ) from transformers.models.bert.configuration_bert import BertConfig @dataclass class AutoEncoderParams: resolution: int in_channels: int downsample: int ch: int out_ch: int ch_mult: list[int] num_res_blocks: int z_channels: int scale_factor: float shift_factor: float def swish(x: Tensor) -> Tensor: return x * torch.sigmoid(x) class AttnBlock(nn.Module): def __init__(self, in_channels: int): super().__init__() self.in_channels = in_channels self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) def attention(self, h_: Tensor) -> Tensor: h_ = self.norm(h_) q = self.q(h_) k = self.k(h_) v = self.v(h_) b, c, h, w = q.shape q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() h_ = nn.functional.scaled_dot_product_attention(q, k, v) return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) def forward(self, x: Tensor) -> Tensor: return x + self.proj_out(self.attention(x)) class ResnetBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h = x h = self.norm1(h) h = swish(h) h = self.conv1(h) h = self.norm2(h) h = swish(h) h = self.conv2(h) if self.in_channels != self.out_channels: x = self.nin_shortcut(x) return x + h class Downsample(nn.Module): def __init__(self, in_channels: int): super().__init__() # no asymmetric padding in torch conv, must do it ourselves self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) def forward(self, x: Tensor): pad = (0, 1, 0, 1) x = nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) return x class Upsample(nn.Module): def __init__(self, in_channels: int): super().__init__() self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) def forward(self, x: Tensor): x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") x = self.conv(x) return x class Encoder(nn.Module): def __init__( self, resolution: int, in_channels: int, ch: int, ch_mult: list[int], num_res_blocks: int, z_channels: int, ): super().__init__() self.ch = ch self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels # downsampling self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) curr_res = resolution in_ch_mult = (1,) + tuple(ch_mult) self.in_ch_mult = in_ch_mult self.down = nn.ModuleList() block_in = self.ch for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() block_in = ch * in_ch_mult[i_level] block_out = ch * ch_mult[i_level] for _ in range(self.num_res_blocks): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) block_in = block_out down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) self.mid.attn_1 = AttnBlock(block_in) self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) # end self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) def forward(self, x: Tensor) -> Tensor: # downsampling hs = [self.conv_in(x)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](hs[-1]) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle h = hs[-1] h = self.mid.block_1(h) h = self.mid.attn_1(h) h = self.mid.block_2(h) # end h = self.norm_out(h) h = swish(h) h = self.conv_out(h) return h class Decoder(nn.Module): def __init__( self, ch: int, out_ch: int, ch_mult: list[int], num_res_blocks: int, in_channels: int, resolution: int, z_channels: int, ): super().__init__() self.ch = ch self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels self.ffactor = 2 ** (self.num_resolutions - 1) # compute in_ch_mult, block_in and curr_res at lowest res block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) # z to block_in self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) self.mid.attn_1 = AttnBlock(block_in) self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() block_out = ch * ch_mult[i_level] for _ in range(self.num_res_blocks + 1): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) block_in = block_out up = nn.Module() up.block = block up.attn = attn if i_level != 0: up.upsample = Upsample(block_in) curr_res = curr_res * 2 self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, z: Tensor) -> Tensor: # z to block_in h = self.conv_in(z) # middle h = self.mid.block_1(h) h = self.mid.attn_1(h) h = self.mid.block_2(h) # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) if i_level != 0: h = self.up[i_level].upsample(h) # end h = self.norm_out(h) h = swish(h) h = self.conv_out(h) return h class DiagonalGaussian(nn.Module): def __init__(self, sample: bool = True, chunk_dim: int = 1): super().__init__() self.sample = sample self.chunk_dim = chunk_dim def forward(self, z: Tensor) -> Tensor: mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) if self.sample: std = torch.exp(0.5 * logvar) return mean + std * torch.randn_like(mean) else: return mean class AutoEncoder(nn.Module): def __init__(self, params: AutoEncoderParams): super().__init__() self.encoder = Encoder( resolution=params.resolution, in_channels=params.in_channels, ch=params.ch, ch_mult=params.ch_mult, num_res_blocks=params.num_res_blocks, z_channels=params.z_channels, ) self.decoder = Decoder( resolution=params.resolution, in_channels=params.in_channels, ch=params.ch, out_ch=params.out_ch, ch_mult=params.ch_mult, num_res_blocks=params.num_res_blocks, z_channels=params.z_channels, ) self.reg = DiagonalGaussian() self.scale_factor = params.scale_factor self.shift_factor = params.shift_factor def encode(self, x: Tensor) -> Tensor: z = self.reg(self.encoder(x)) z = self.scale_factor * (z - self.shift_factor) return z def decode(self, z: Tensor) -> Tensor: z = z / self.scale_factor + self.shift_factor return self.decoder(z) def forward(self, x: Tensor) -> Tensor: return self.decode(self.encode(x)) def print_load_warning(missing: list[str], unexpected: list[str]) -> None: if len(missing) > 0 and len(unexpected) > 0: print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) print("\n" + "-" * 79 + "\n") print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) elif len(missing) > 0: print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) elif len(unexpected) > 0: print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) def load_ae(local_path: str) -> AutoEncoder: ae_params = AutoEncoderParams( resolution=256, in_channels=3, downsample=8, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16, scale_factor=0.3611, shift_factor=0.1159, ) # Loading the autoencoder ae = AutoEncoder(ae_params) if local_path is not None: sd = load_sft(local_path) missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) print_load_warning(missing, unexpected) return ae, ae_params #torch.set_printoptions(profile="full") class DropPathEvaVit(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ def __init__(self, drop_prob=None): super(DropPathEvaVit, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training) def extra_repr(self) -> str: return 'p={}'.format(self.drop_prob) class MlpEvaVit(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) # x = self.drop(x) # commit this for the orignal BERT implement x = self.fc2(x) x = self.drop(x) return x class AttentionEvaVit(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=None, attn_head_dim=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads if attn_head_dim is not None: head_dim = attn_head_dim all_head_dim = head_dim * self.num_heads self.scale = qk_scale or head_dim**-0.5 self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) if qkv_bias: self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) else: self.q_bias = None self.v_bias = None if window_size: self.window_size = window_size self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH # cls to token & token 2 cls & cls to cls # get pair-wise relative position index for each token inside the window coords_h = torch.arange(window_size[0]) coords_w = torch.arange(window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += window_size[1] - 1 relative_coords[:, :, 0] *= 2 * window_size[1] - 1 relative_position_index = \ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = self.num_relative_distance - 3 relative_position_index[0:, 0] = self.num_relative_distance - 2 relative_position_index[0, 0] = self.num_relative_distance - 1 self.register_buffer("relative_position_index", relative_position_index) else: self.window_size = None self.relative_position_bias_table = None self.relative_position_index = None self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(all_head_dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x, rel_pos_bias=None): B, N, C = x.shape qkv_bias = None if self.q_bias is not None: qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) q = q * self.scale attn = (q @ k.transpose(-2, -1)) if self.relative_position_bias_table is not None: relative_position_bias = \ self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if rel_pos_bias is not None: attn = attn + rel_pos_bias attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, -1) x = self.proj(x) x = self.proj_drop(x) return x class BlockEvaVit(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, window_size=None, attn_head_dim=None): super().__init__() self.norm1 = norm_layer(dim) self.attn = AttentionEvaVit(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPathEvaVit(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = MlpEvaVit(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) if init_values is not None and init_values > 0: self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) else: self.gamma_1, self.gamma_2 = None, None def forward(self, x, rel_pos_bias=None): if self.gamma_1 is None: x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) x = x + self.drop_path(self.mlp(self.norm2(x))) else: x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) return x class PatchEmbedEvaVit(nn.Module): """ Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.img_size = img_size self.patch_size = patch_size self.num_patches = num_patches self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x, **kwargs): B, C, H, W = x.shape # FIXME look at relaxing size constraints assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x).flatten(2).transpose(1, 2) return x class RelativePositionBiasEvaVit(nn.Module): def __init__(self, window_size, num_heads): super().__init__() self.window_size = window_size self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 self.relative_position_bias_table = nn.Parameter(torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH # cls to token & token 2 cls & cls to cls # get pair-wise relative position index for each token inside the window coords_h = torch.arange(window_size[0]) coords_w = torch.arange(window_size[1]) coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += window_size[1] - 1 relative_coords[:, :, 0] *= 2 * window_size[1] - 1 relative_position_index = \ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww relative_position_index[0, 0:] = self.num_relative_distance - 3 relative_position_index[0:, 0] = self.num_relative_distance - 2 relative_position_index[0, 0] = self.num_relative_distance - 1 self.register_buffer("relative_position_index", relative_position_index) # trunc_normal_(self.relative_position_bias_table, std=.02) def forward(self): relative_position_bias = \ self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww class VisionTransformerEvaVit(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, use_mean_pooling=True, init_scale=0.001, use_checkpoint=False): super().__init__() self.image_size = img_size self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.patch_embed = PatchEmbedEvaVit(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if use_abs_pos_emb: self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) else: self.pos_embed = None self.pos_drop = nn.Dropout(p=drop_rate) if use_shared_rel_pos_bias: self.rel_pos_bias = RelativePositionBiasEvaVit(window_size=self.patch_embed.patch_shape, num_heads=num_heads) else: self.rel_pos_bias = None self.use_checkpoint = use_checkpoint dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.use_rel_pos_bias = use_rel_pos_bias self.blocks = nn.ModuleList([ BlockEvaVit(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) for i in range(depth) ]) # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() if self.pos_embed is not None: trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) # trunc_normal_(self.mask_token, std=.02) # if isinstance(self.head, nn.Linear): # trunc_normal_(self.head.weight, std=.02) self.apply(self._init_weights) self.fix_init_weight() self.ln_vision = nn.LayerNorm(self.num_features) def fix_init_weight(self): def rescale(param, layer_id): param.div_(math.sqrt(2.0 * layer_id)) for layer_id, layer in enumerate(self.blocks): rescale(layer.attn.proj.weight.data, layer_id + 1) rescale(layer.mlp.fc2.weight.data, layer_id + 1) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) _initialize_weights = _init_weights def get_classifier(self): return self.head def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.patch_embed(x) batch_size, seq_len, _ = x.size() cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks x = torch.cat((cls_tokens, x), dim=1) if self.pos_embed is not None: x = x + self.pos_embed x = self.pos_drop(x) rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None for blk in self.blocks: if self.use_checkpoint: x = checkpoint.checkpoint(blk, x, rel_pos_bias) else: x = blk(x, rel_pos_bias) return x def forward(self, x): x = self.forward_features(x) # x = self.head(x) return x def get_intermediate_layers(self, x): x = self.patch_embed(x) batch_size, seq_len, _ = x.size() cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks x = torch.cat((cls_tokens, x), dim=1) if self.pos_embed is not None: x = x + self.pos_embed x = self.pos_drop(x) features = [] rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None for blk in self.blocks: x = blk(x, rel_pos_bias) features.append(x) return features def get_num_layer(self, var_name=""): if var_name in ("cls_token", "mask_token", "pos_embed"): return 0 elif var_name.startswith("patch_embed"): return 0 elif var_name.startswith("rel_pos_bias"): return len(self.blocks) - 1 elif var_name.startswith("blocks"): layer_id = int(var_name.split('.')[1]) return layer_id + 1 else: return len(self.blocks) def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16", cache_dir="./",): model = VisionTransformerEvaVit( img_size=img_size, patch_size=14, use_mean_pooling=False, embed_dim=1408, depth=39, num_heads=1408 // 88, mlp_ratio=4.3637, qkv_bias=True, drop_path_rate=drop_path_rate, norm_layer=partial(nn.LayerNorm, eps=1e-6), use_checkpoint=use_checkpoint, ) cache_path = cache_dir state_dict = torch.load(cache_path+"/eva_vit_g.pth", map_location="cpu") interpolate_pos_embed(model, state_dict) incompatible_keys = model.load_state_dict(state_dict, strict=False) #print(incompatible_keys) return model class BertEmbeddings(nn.Module): """Construct the embeddings from word and position embeddings.""" def __init__(self, config): super().__init__() self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load # any TensorFlow checkpoint file self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") self.config = config def forward( self, input_ids=None, position_ids=None, query_embeds=None, past_key_values_length=0, ): if input_ids is not None: seq_length = input_ids.size()[1] else: seq_length = 0 if position_ids is None: position_ids = self.position_ids[:, past_key_values_length:seq_length + past_key_values_length].clone() if input_ids is not None: embeddings = self.word_embeddings(input_ids) if self.position_embedding_type == "absolute": position_embeddings = self.position_embeddings(position_ids) embeddings = embeddings + position_embeddings if query_embeds is not None: embeddings = torch.cat((query_embeds, embeddings), dim=1) #print(query_embeds.shape, embeddings.shape) else: embeddings = query_embeds embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings class BertSelfAttention(nn.Module): def __init__(self, config, is_cross_attention): super().__init__() self.config = config if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError("The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads)) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(config.hidden_size, self.all_head_size) if is_cross_attention: self.key = nn.Linear(config.encoder_width, self.all_head_size) self.value = nn.Linear(config.encoder_width, self.all_head_size) else: self.key = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") if (self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query"): self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.save_attention = False def save_attn_gradients(self, attn_gradients): self.attn_gradients = attn_gradients def get_attn_gradients(self): return self.attn_gradients def save_attention_map(self, attention_map): self.attention_map = attention_map def get_attention_map(self): return self.attention_map def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + ( self.num_attention_heads, self.attention_head_size, ) x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, ): # If this is instantiated as a cross-attention module, the keys # and values come from an encoder; the attention mask needs to be # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None if is_cross_attention: key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) #print(key_layer.shape, value_layer.shape) attention_mask = encoder_attention_mask elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) key_layer = torch.cat([past_key_value[0], key_layer], dim=2) value_layer = torch.cat([past_key_value[1], value_layer], dim=2) #print(past_key_value[0].shape, key_layer.shape) else: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) mixed_query_layer = self.query(hidden_states) query_layer = self.transpose_for_scores(mixed_query_layer) # if past_key_value is not None: # print(query_layer.shape) past_key_value = (key_layer, value_layer) #print(key_layer.shape, value_layer.shape) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) #if is_cross_attention: # if attention_scores.shape[2] == 32: # attention_scores_save = attention_scores[0].detach().cpu().numpy() # print(attention_scores_save.shape) # np.save('attention_scores_causal_text_child.npy', attention_scores_save) if (self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query"): seq_length = hidden_states.size()[1] position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) distance = position_ids_l - position_ids_r positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility if self.position_embedding_type == "relative_key": relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) attention_scores = attention_scores + relative_position_scores elif self.position_embedding_type == "relative_key_query": relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) attention_scores = (attention_scores + relative_position_scores_query + relative_position_scores_key) attention_scores = attention_scores / math.sqrt(self.attention_head_size) if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in BertModel forward() function) attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. attention_probs = nn.Softmax(dim=-1)(attention_scores) if is_cross_attention and self.save_attention: self.save_attention_map(attention_probs) attention_probs.register_hook(self.save_attn_gradients) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. attention_probs_dropped = self.dropout(attention_probs) # Mask heads if we want to if head_mask is not None: attention_probs_dropped = attention_probs_dropped * head_mask context_layer = torch.matmul(attention_probs_dropped, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, ) context_layer = context_layer.view(*new_context_layer_shape) outputs = ((context_layer, attention_probs) if output_attentions else (context_layer, )) outputs = outputs + (past_key_value, ) return outputs class BertSelfOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class BertAttention(nn.Module): def __init__(self, config, is_cross_attention=False): super().__init__() self.self = BertSelfAttention(config, is_cross_attention) self.output = BertSelfOutput(config) self.pruned_heads = set() def prune_heads(self, heads): if len(heads) == 0: return heads, index = find_pruneable_heads_and_indices( heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads, ) # Prune linear layers self.self.query = prune_linear_layer(self.self.query, index) self.self.key = prune_linear_layer(self.self.key, index) self.self.value = prune_linear_layer(self.self.value, index) self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) # Update hyper params and store pruned heads self.self.num_attention_heads = self.self.num_attention_heads - len(heads) self.self.all_head_size = (self.self.attention_head_size * self.self.num_attention_heads) self.pruned_heads = self.pruned_heads.union(heads) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, ): self_outputs = self.self( hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output, ) + self_outputs[1:] # add attentions if we output them return outputs class BertIntermediate(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: self.intermediate_act_fn = config.hidden_act def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states class BertOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class BertLayer(nn.Module): def __init__(self, config, layer_num): super().__init__() self.config = config self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = BertAttention(config) self.layer_num = layer_num if (self.config.add_cross_attention and layer_num % self.config.cross_attention_freq == 0): self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention) self.has_cross_attention = True else: self.has_cross_attention = False self.intermediate = BertIntermediate(config) self.output = BertOutput(config) self.intermediate_query = BertIntermediate(config) self.output_query = BertOutput(config) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, output_attentions=False, query_length=0, ): # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = (past_key_value[:2] if past_key_value is not None else None) # if past_key_value is not None: # print(hidden_states.shape, attention_mask.shape) #print(hidden_states.shape, attention_mask.shape) # casual attention for query embeds with self attention self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, past_key_value=self_attn_past_key_value, ) #print('attention_mask', attention_mask.shape) # if attention_mask.shape[-1] == 77: # print('attention_mask', attention_mask[0]) attention_output = self_attention_outputs[0] outputs = self_attention_outputs[1:-1] present_key_value = self_attention_outputs[-1] #print(present_key_value[0].shape) if query_length > 0: query_attention_output = attention_output[:, :query_length, :] if self.has_cross_attention: assert (encoder_hidden_states is not None), "encoder_hidden_states must be given for cross-attention layers" #print(attention_mask.shape) cross_attention_outputs = self.crossattention( query_attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, output_attentions=output_attentions, ) query_attention_output = cross_attention_outputs[0] outputs = (outputs + cross_attention_outputs[1:-1]) # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk_query, self.chunk_size_feed_forward, self.seq_len_dim, query_attention_output, ) if attention_output.shape[1] > query_length: layer_output_text = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output[:, query_length:, :], ) layer_output = torch.cat([layer_output, layer_output_text], dim=1) else: layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output, ) outputs = (layer_output, ) + outputs outputs = outputs + (present_key_value, ) return outputs def feed_forward_chunk(self, attention_output): intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) return layer_output def feed_forward_chunk_query(self, attention_output): intermediate_output = self.intermediate_query(attention_output) layer_output = self.output_query(intermediate_output, attention_output) return layer_output class BertEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config self.layer = nn.ModuleList([BertLayer(config, i) for i in range(config.num_hidden_layers)]) def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=None, output_attentions=False, output_hidden_states=False, return_dict=True, query_length=0, ): all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None all_cross_attentions = (() if output_attentions and self.config.add_cross_attention else None) next_decoder_cache = () if use_cache else None for i in range(self.config.num_hidden_layers): layer_module = self.layer[i] if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None # if past_key_value is not None: # print(past_key_value[0].shape, past_key_value[1].shape) if getattr(self.config, "gradient_checkpointing", False) and self.training: if use_cache: logger.warn("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") use_cache = False def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions, query_length) return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(layer_module), hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, ) else: layer_outputs = layer_module( hidden_states, attention_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, query_length, ) # if past_key_value is not None: # print(hidden_states.shape, attention_mask.shape) # print(len(past_key_value)) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[-1], ) #print(layer_outputs[-1][0].shape) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1], ) all_cross_attentions = all_cross_attentions + (layer_outputs[2], ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states, ) if not return_dict: return tuple(v for v in [ hidden_states, next_decoder_cache, all_hidden_states, all_self_attentions, all_cross_attentions, ] if v is not None) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, ) class BertPooler(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states): # We "pool" the model by simply taking the hidden state corresponding # to the first token. first_token_tensor = hidden_states[:, 0] pooled_output = self.dense(first_token_tensor) pooled_output = self.activation(pooled_output) return pooled_output class BertPredictionHeadTransform(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) if isinstance(config.hidden_act, str): self.transform_act_fn = ACT2FN[config.hidden_act] else: self.transform_act_fn = config.hidden_act self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.LayerNorm(hidden_states) return hidden_states class BertLMPredictionHead(nn.Module): def __init__(self, config): super().__init__() self.transform = BertPredictionHeadTransform(config) # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) return hidden_states class BertOnlyMLMHead(nn.Module): def __init__(self, config): super().__init__() self.predictions = BertLMPredictionHead(config) def forward(self, sequence_output): prediction_scores = self.predictions(sequence_output) return prediction_scores class BertPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = BertConfig base_model_prefix = "bert" _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() class BertModel(BertPreTrainedModel): """ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of cross-attention is added between the self-attention layers, following the architecture described in `Attention is all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an input to the forward pass. """ def __init__(self, config, add_pooling_layer=False): super().__init__(config) self.config = config self.embeddings = BertEmbeddings(config) self.encoder = BertEncoder(config) self.pooler = BertPooler(config) if add_pooling_layer else None self.init_weights() def get_input_embeddings(self): return self.embeddings.word_embeddings def set_input_embeddings(self, value): self.embeddings.word_embeddings = value def _prune_heads(self, heads_to_prune): """ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel """ for layer, heads in heads_to_prune.items(): self.encoder.layer[layer].attention.prune_heads(heads) def get_extended_attention_mask( self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool, is_casual: bool, has_query: bool = False, ) -> Tensor: """ Makes broadcastable attention and causal masks so that future and masked tokens are ignored. Arguments: attention_mask (:obj:`torch.Tensor`): Mask with ones indicating tokens to attend to, zeros for tokens to ignore. input_shape (:obj:`Tuple[int]`): The shape of the input to the model. device: (:obj:`torch.device`): The device of the input to the model. Returns: :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. """ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. #print(attention_mask.dim()) if attention_mask.dim() == 3: extended_attention_mask = attention_mask[:, None, :, :] elif attention_mask.dim() == 2: # Provided a padding mask of dimensions [batch_size, seq_length] # - if the model is a decoder, apply a causal mask in addition to the padding mask # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] if is_decoder or is_casual: batch_size, seq_length = input_shape #print(input_shape) if not is_decoder and seq_length > 32: query_length = 32 text_length = seq_length - query_length query_ids = torch.arange(query_length, device=device) query_causal_mask = (query_ids[None, None, :].repeat(batch_size, query_length, 1) <= query_ids[None, :, None]) causal_mask = torch.ones((batch_size, seq_length, seq_length), device=device) causal_mask[:, :query_length, :query_length] = query_causal_mask # print(query_causal_mask.shape, causal_mask.shape) #print(causal_mask[0]) else: seq_ids = torch.arange(seq_length, device=device) causal_mask = (seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]) # add a prefix ones mask to the causal mask # causal and attention masks must have same type with pytorch version < 1.3 causal_mask = causal_mask.to(attention_mask.dtype) # if is_decoder: # print(causal_mask.shape, attention_mask.shape) #print(causal_mask.shape, attention_mask.shape) if causal_mask.shape[1] < attention_mask.shape[1]: prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] if has_query: # UniLM style attention mask causal_mask = torch.cat( [ torch.zeros( (batch_size, prefix_seq_len, seq_length), device=device, dtype=causal_mask.dtype, ), causal_mask, ], axis=1, ) causal_mask = torch.cat( [ torch.ones( (batch_size, causal_mask.shape[1], prefix_seq_len), device=device, dtype=causal_mask.dtype, ), causal_mask, ], axis=-1, ) #print(has_query, causal_mask.shape) #print(causal_mask[0]) extended_attention_mask = (causal_mask[:, None, :, :] * attention_mask[:, None, None, :]) #print(extended_attention_mask[0]) #print('extended_attention_mask', extended_attention_mask.shape) else: extended_attention_mask = attention_mask[:, None, None, :] #print(attention_mask.shape, extended_attention_mask.shape) else: raise ValueError("Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( input_shape, attention_mask.shape)) # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask def forward( self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, is_decoder=False, ): r""" encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. use_cache (:obj:`bool`, `optional`): If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up decoding (see :obj:`past_key_values`). """ output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions) output_hidden_states = (output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states) return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) # use_cache = use_cache if use_cache is not None else self.config.use_cache if input_ids is None: assert (query_embeds is not None), "You have to specify query_embeds when input_ids is None" #if query_embeds is not None: if query_embeds is not None and query_embeds.shape[1] == 32: is_casual = True else: is_casual = False past_key_values_length = (past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0) query_length = query_embeds.shape[1] if query_embeds is not None else 0 embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, query_embeds=query_embeds, past_key_values_length=past_key_values_length, ) input_shape = embedding_output.size()[:-1] batch_size, seq_length = input_shape device = embedding_output.device #print('attention_mask', attention_mask) if attention_mask is None: attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) #print(seq_length, past_key_values_length) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. if is_decoder: #print(attention_mask.shape, input_ids.shape) extended_attention_mask = self.get_extended_attention_mask( attention_mask, input_ids.shape, device, is_decoder, is_casual, has_query=(query_embeds is not None), ) else: extended_attention_mask = self.get_extended_attention_mask( attention_mask, input_shape, device, is_decoder, is_casual, ) #print(is_decoder, extended_attention_mask.shape) # if is_decoder: # print(extended_attention_mask[0,0,:,32:]) # if attention_mask is not None: # print(input_ids, embedding_output.shape, extended_attention_mask.shape) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if encoder_hidden_states is not None: if type(encoder_hidden_states) == list: encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() else: ( encoder_batch_size, encoder_sequence_length, _, ) = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if type(encoder_attention_mask) == list: encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] elif encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) #print(is_casual, extended_attention_mask.shape, encoder_attention_mask.shape, encoder_extended_attention_mask.shape) else: encoder_extended_attention_mask = None # if input_ids is not None and query_embeds is not None: # print(extended_attention_mask.shape, encoder_extended_attention_mask.shape) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape bsz x n_heads x N x N # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) #print(head_mask) encoder_outputs = self.encoder( embedding_output, attention_mask=extended_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, query_length=query_length, ) # if is_decoder: # print(embedding_output.shape, attention_mask.shape, len(past_key_values)) #print(embedding_output.shape, extended_attention_mask.shape, encoder_hidden_states.shape, encoder_extended_attention_mask.shape) #print(extended_attention_mask[0], encoder_extended_attention_mask[0]) #print(query_embeds.shape, encoder_hidden_states.shape) sequence_output = encoder_outputs[0] pooled_output = (self.pooler(sequence_output) if self.pooler is not None else None) if not return_dict: return (sequence_output, pooled_output) + encoder_outputs[1:] return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, past_key_values=encoder_outputs.past_key_values, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, cross_attentions=encoder_outputs.cross_attentions, ) class BertLMHeadModel(BertPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] def __init__(self, config): super().__init__(config) self.bert = BertModel(config, add_pooling_layer=False) self.cls = BertOnlyMLMHead(config) self.init_weights() def get_output_embeddings(self): return self.cls.predictions.decoder def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings def forward( self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, labels=None, past_key_values=None, use_cache=True, output_attentions=None, output_hidden_states=None, return_dict=None, return_logits=False, is_decoder=True, reduction="mean", ): r""" encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. use_cache (:obj:`bool`, `optional`): If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up decoding (see :obj:`past_key_values`). Returns: Example:: >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig >>> import torch >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') >>> config = BertConfig.from_pretrained("bert-base-cased") >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> outputs = model(**inputs) >>> prediction_logits = outputs.logits """ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) if labels is not None: use_cache = False if past_key_values is not None: query_embeds = None #print(len(past_key_values)) #print('attention_mask', attention_mask) outputs = self.bert( input_ids, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask, query_embeds=query_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, is_decoder=is_decoder, ) sequence_output = outputs[0] if query_embeds is not None: sequence_output = outputs[0][:, query_embeds.shape[1]:, :] prediction_scores = self.cls(sequence_output) if return_logits: return prediction_scores[:, :-1, :].contiguous() lm_loss = None if labels is not None: # we are doing next-token prediction; shift prediction scores and input ids by one shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() labels = labels[:, 1:].contiguous() loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) lm_loss = loss_fct( shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1), ) if reduction == "none": lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) if not return_dict: output = (prediction_scores, ) + outputs[2:] return ((lm_loss, ) + output) if lm_loss is not None else output return CausalLMOutputWithCrossAttentions( loss=lm_loss, logits=prediction_scores, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, ) def prepare_inputs_for_generation(self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs): # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: attention_mask = input_ids.new_ones(input_ids.shape) query_mask = input_ids.new_ones(query_embeds.shape[:-1]) attention_mask = torch.cat([query_mask, attention_mask], dim=-1) # cut decoder_input_ids if past is used if past is not None: input_ids = input_ids[:, -1:] return { "input_ids": input_ids, "query_embeds": query_embeds, "attention_mask": attention_mask, "past_key_values": past, "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), "is_decoder": True, } def _reorder_cache(self, past, beam_idx): reordered_past = () for layer_past in past: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past), ) return reordered_past class BertForMaskedLM(BertPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] def __init__(self, config): super().__init__(config) self.bert = BertModel(config, add_pooling_layer=False) self.cls = BertOnlyMLMHead(config) self.init_weights() def get_output_embeddings(self): return self.cls.predictions.decoder def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings def forward( self, input_ids=None, attention_mask=None, position_ids=None, head_mask=None, query_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, return_logits=False, is_decoder=False, ): r""" labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` """ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict) outputs = self.bert( input_ids, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask, query_embeds=query_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, is_decoder=is_decoder, ) if query_embeds is not None: sequence_output = outputs[0][:, query_embeds.shape[1]:, :] prediction_scores = self.cls(sequence_output) if return_logits: return prediction_scores masked_lm_loss = None if labels is not None: loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: output = (prediction_scores, ) + outputs[2:] return (((masked_lm_loss, ) + output) if masked_lm_loss is not None else output) return MaskedLMOutput( loss=masked_lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) class Mlp(nn.Module): """MLP as used in Vision Transformer, MLP-Mixer and related networks""" def __init__( self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class Attention(nn.Module): def __init__( self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, ): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights self.scale = qk_scale or head_dim**-0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.attn_gradients = None self.attention_map = None def save_attn_gradients(self, attn_gradients): self.attn_gradients = attn_gradients def get_attn_gradients(self): return self.attn_gradients def save_attention_map(self, attention_map): self.attention_map = attention_map def get_attention_map(self): return self.attention_map def forward(self, x, register_hook=False): B, N, C = x.shape qkv = (self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)) q, k, v = ( qkv[0], qkv[1], qkv[2], ) # make torchscript happy (cannot use tensor as tuple) attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) if register_hook: self.save_attention_map(attn) attn.register_hook(self.save_attn_gradients) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class Block(nn.Module): def __init__( self, dim, num_heads, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False, ): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, ) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, ) # if use_grad_checkpointing: # self.attn = checkpoint_wrapper(self.attn) # self.mlp = checkpoint_wrapper(self.mlp) def forward(self, x, register_hook=False): x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) x = x + self.drop_path(self.mlp(self.norm2(x))) return x class VisionTransformer(nn.Module): """Vision Transformer A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - https://arxiv.org/abs/2010.11929 """ def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, representation_size=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=None, use_grad_checkpointing=False, ckpt_layer=0, ): """ Args: img_size (int, tuple): input image size patch_size (int, tuple): patch size in_chans (int): number of input channels num_classes (int): number of classes for classification head embed_dim (int): embedding dimension depth (int): depth of transformer num_heads (int): number of attention heads mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True qk_scale (float): override default qk scale of head_dim ** -0.5 if set representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate norm_layer: (nn.Module): normalization layer """ super().__init__() self.num_features = (self.embed_dim) = embed_dim # num_features for consistency with other models norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.ModuleList([ Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, use_grad_checkpointing=(use_grad_checkpointing and i >= depth - ckpt_layer), ) for i in range(depth) ]) self.norm = norm_layer(embed_dim) trunc_normal_(self.pos_embed, std=0.02) trunc_normal_(self.cls_token, std=0.02) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) @torch.jit.ignore def no_weight_decay(self): return {"pos_embed", "cls_token"} def forward(self, x, register_blk=-1): B = x.shape[0] x = self.patch_embed(x) cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks x = torch.cat((cls_tokens, x), dim=1) x = x + self.pos_embed[:, :x.size(1), :] x = self.pos_drop(x) for i, blk in enumerate(self.blocks): x = blk(x, register_blk == i) x = self.norm(x) return x @torch.no_grad() def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ""): """Load weights from .npz checkpoints for official Google Brain Flax implementation""" import numpy as np def _n2p(w, t=True): if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: w = w.flatten() if t: if w.ndim == 4: w = w.transpose([3, 2, 0, 1]) elif w.ndim == 3: w = w.transpose([2, 0, 1]) elif w.ndim == 2: w = w.transpose([1, 0]) return torch.from_numpy(w) w = np.load(checkpoint_path) if not prefix and "opt/target/embedding/kernel" in w: prefix = "opt/target/" if hasattr(model.patch_embed, "backbone"): # hybrid backbone = model.patch_embed.backbone stem_only = not hasattr(backbone, "stem") stem = backbone if stem_only else backbone.stem stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f"{prefix}conv_root/kernel"]))) stem.norm.weight.copy_(_n2p(w[f"{prefix}gn_root/scale"])) stem.norm.bias.copy_(_n2p(w[f"{prefix}gn_root/bias"])) if not stem_only: for i, stage in enumerate(backbone.stages): for j, block in enumerate(stage.blocks): bp = f"{prefix}block{i + 1}/unit{j + 1}/" for r in range(3): getattr(block, f"conv{r + 1}").weight.copy_(_n2p(w[f"{bp}conv{r + 1}/kernel"])) getattr(block, f"norm{r + 1}").weight.copy_(_n2p(w[f"{bp}gn{r + 1}/scale"])) getattr(block, f"norm{r + 1}").bias.copy_(_n2p(w[f"{bp}gn{r + 1}/bias"])) if block.downsample is not None: block.downsample.conv.weight.copy_(_n2p(w[f"{bp}conv_proj/kernel"])) block.downsample.norm.weight.copy_(_n2p(w[f"{bp}gn_proj/scale"])) block.downsample.norm.bias.copy_(_n2p(w[f"{bp}gn_proj/bias"])) embed_conv_w = _n2p(w[f"{prefix}embedding/kernel"]) else: embed_conv_w = adapt_input_conv(model.patch_embed.proj.weight.shape[1], _n2p(w[f"{prefix}embedding/kernel"])) model.patch_embed.proj.weight.copy_(embed_conv_w) model.patch_embed.proj.bias.copy_(_n2p(w[f"{prefix}embedding/bias"])) model.cls_token.copy_(_n2p(w[f"{prefix}cls"], t=False)) pos_embed_w = _n2p(w[f"{prefix}Transformer/posembed_input/pos_embedding"], t=False) if pos_embed_w.shape != model.pos_embed.shape: pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights pos_embed_w, model.pos_embed, getattr(model, "num_tokens", 1), model.patch_embed.grid_size, ) model.pos_embed.copy_(pos_embed_w) model.norm.weight.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/scale"])) model.norm.bias.copy_(_n2p(w[f"{prefix}Transformer/encoder_norm/bias"])) # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) for i, block in enumerate(model.blocks.children()): block_prefix = f"{prefix}Transformer/encoderblock_{i}/" mha_prefix = block_prefix + "MultiHeadDotProductAttention_1/" block.norm1.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/scale"])) block.norm1.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_0/bias"])) block.attn.qkv.weight.copy_( torch.cat([_n2p(w[f"{mha_prefix}{n}/kernel"], t=False).flatten(1).T for n in ("query", "key", "value")])) block.attn.qkv.bias.copy_( torch.cat([_n2p(w[f"{mha_prefix}{n}/bias"], t=False).reshape(-1) for n in ("query", "key", "value")])) block.attn.proj.weight.copy_(_n2p(w[f"{mha_prefix}out/kernel"]).flatten(1)) block.attn.proj.bias.copy_(_n2p(w[f"{mha_prefix}out/bias"])) for r in range(2): getattr(block.mlp, f"fc{r + 1}").weight.copy_(_n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/kernel"])) getattr(block.mlp, f"fc{r + 1}").bias.copy_(_n2p(w[f"{block_prefix}MlpBlock_3/Dense_{r}/bias"])) block.norm2.weight.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/scale"])) block.norm2.bias.copy_(_n2p(w[f"{block_prefix}LayerNorm_2/bias"])) def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): # Rescale the grid of position embeddings when loading from state_dict. Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 print("Resized position embedding: %s to %s", posemb.shape, posemb_new.shape) ntok_new = posemb_new.shape[1] if num_tokens: posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] ntok_new -= num_tokens else: posemb_tok, posemb_grid = posemb[:, :0], posemb[0] gs_old = int(math.sqrt(len(posemb_grid))) if not len(gs_new): # backwards compatibility gs_new = [int(math.sqrt(ntok_new))] * 2 assert len(gs_new) >= 2 print("Position embedding grid-size from %s to %s", [gs_old, gs_old], gs_new) posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode="bicubic", align_corners=False) posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) posemb = torch.cat([posemb_tok, posemb_grid], dim=1) return def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): # interpolate position embedding embedding_size = pos_embed_checkpoint.shape[-1] num_patches = visual_encoder.patch_embed.num_patches num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches # height (== width) for the checkpoint position embedding orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5) # height (== width) for the new position embedding new_size = int(num_patches**0.5) if orig_size != new_size: # class_token and dist_token are kept unchanged extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] # only the position tokens are interpolated pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) print("reshape position embedding from %d to %d" % (orig_size**2, new_size**2)) return new_pos_embed else: return pos_embed_checkpoint # class Blip2Base(BaseModel): class Blip2Base(PreTrainedModel): config_class = BertConfig def __init__(self, config): super().__init__(config) @property def device(self): return list(self.parameters())[0].device @classmethod def init_tokenizer(cls, truncation_side="right"): tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side=truncation_side) tokenizer.add_special_tokens({"bos_token": "[DEC]"}) return tokenizer @classmethod def init_Qformer(cls, encoder_config, num_query_token, vision_width, cross_attention_freq=2, cache_dir=""): #print ("loading") encoder_config = BertConfig.from_pretrained("bert-base-uncased") encoder_config.encoder_width = vision_width # insert cross-attention layer every other block encoder_config.add_cross_attention = True encoder_config.cross_attention_freq = cross_attention_freq encoder_config.query_length = num_query_token Qformer = BertLMHeadModel(encoder_config) # .from_pretrained("bert-base-uncased", config=encoder_config, cache_dir=cache_dir) query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size)) query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) return Qformer, query_tokens class VectorQuantizer2(nn.Module): """ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix multiplications and allows for post-hoc remapping of indices. """ # NOTE: due to a bug the beta term was applied to the wrong term. for # backwards compatibility we use the buggy version by default, but you can # specify legacy=False to fix it. def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True): super().__init__() self.n_e = n_e self.e_dim = e_dim self.beta = beta self.legacy = legacy self.embedding = nn.Embedding(self.n_e, self.e_dim) self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) self.remap = remap if self.remap is not None: self.register_buffer("used", torch.tensor(np.load(self.remap))) self.re_embed = self.used.shape[0] self.unknown_index = unknown_index # "random" or "extra" or integer if self.unknown_index == "extra": self.unknown_index = self.re_embed self.re_embed = self.re_embed + 1 print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " f"Using {self.unknown_index} for unknown indices.") else: self.re_embed = n_e self.sane_index_shape = sane_index_shape def remap_to_used(self, inds): ishape = inds.shape assert len(ishape) > 1 inds = inds.reshape(ishape[0], -1) used = self.used.to(inds) match = (inds[:, :, None] == used[None, None, ...]).long() new = match.argmax(-1) unknown = match.sum(2) < 1 if self.unknown_index == "random": new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) else: new[unknown] = self.unknown_index return new.reshape(ishape) def unmap_to_all(self, inds): ishape = inds.shape assert len(ishape) > 1 inds = inds.reshape(ishape[0], -1) used = self.used.to(inds) if self.re_embed > self.used.shape[0]: # extra token inds[inds >= self.used.shape[0]] = 0 # simply set to zero back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) return back.reshape(ishape) # def l2norm(self, t): # return F.normalize(t, p = 2, dim = -1) def forward(self, z, temp=None, rescale_logits=False, return_logits=False): assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" assert rescale_logits is False, "Only for interface compatible with Gumbel" assert return_logits is False, "Only for interface compatible with Gumbel" # reshape z -> (batch, height, width, channel) and flatten #z = rearrange(z, 'b c h w -> b h w c').contiguous() bz = z.shape[0] z_flattened = z.view(-1, self.e_dim) #print('z_flattened', z_flattened.shape) # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ torch.sum(self.embedding.weight**2, dim=1) - 2 * \ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) min_encoding_indices = torch.argmin(d, dim=1) z_q = self.embedding(min_encoding_indices).view(z.shape) perplexity = None min_encodings = None # compute loss for embedding if not self.legacy: loss = self.beta * torch.mean((z_q.detach() - z)**2) + torch.mean((z_q - z.detach())**2) else: loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2) # preserve gradients z_q = z + (z_q - z).detach() # reshape back to match original input shape #z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() z_q = z_q.reshape(bz, -1, z_q.shape[-1]) if self.remap is not None: min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis min_encoding_indices = self.remap_to_used(min_encoding_indices) min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten if self.sane_index_shape: min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) return z_q, loss, min_encoding_indices def get_codebook_entry(self, indices, shape=None): # shape specifying (batch, height, width, channel) if self.remap is not None: indices = indices.reshape(shape[0], -1) # add batch axis indices = self.unmap_to_all(indices) indices = indices.reshape(-1) # flatten again # get quantized latent vectors z_q = self.embedding(indices) if shape is not None: z_q = z_q.view(shape) # reshape back to match original input shape z_q = z_q.permute(0, 3, 1, 2).contiguous() return z_q class Blip2QformerQuantizer(Blip2Base): """ BLIP2 first-stage model with Q-former and ViT. Supported model types: - pretrained: pretrained model with vit-g - pretrain_vitL: pretrained model with vit-large - coco: fintuned model on coco Usage: >>> from lavis.models import load_model >>> model = load_model("blip2", "pretrain") """ PRETRAINED_MODEL_CONFIG_DICT = { "pretrain": "configs/models/blip2/blip2_pretrain.yaml", "pretrain_vitL": "configs/models/blip2/blip2_pretrain_vitL.yaml", "coco": "configs/models/blip2/blip2_coco.yaml", } def __init__(self, config, img_size=224, drop_path_rate=0, use_grad_checkpoint=False, freeze_vit=True, num_query_token=32, cross_attention_freq=2, embed_dim=256, max_txt_len=32, codebook_embed_dim=32, n_embed=8192, recon_s=True, blocks_for_image=True, decode_depth=4, use_recon_s_for_image=False, image_features_dim=1024, visual_encoder_num_features=1408, cache_dir="./"): super().__init__(config) self.tokenizer = self.init_tokenizer() self.codebook_embed_dim = codebook_embed_dim self.n_embed = n_embed self.recon_s = recon_s self.blocks_for_image = blocks_for_image self.use_recon_s_for_image = use_recon_s_for_image self.depth = decode_depth self.image_features_dim = image_features_dim self.Qformer, self.query_tokens = self.init_Qformer(config, num_query_token, visual_encoder_num_features, cache_dir=cache_dir) self.Qformer.cls = None self.Qformer.bert.embeddings.word_embeddings = None self.Qformer.bert.embeddings.position_embeddings = None for layer in self.Qformer.bert.encoder.layer: layer.output = None layer.intermediate = None for name, param in self.Qformer.named_parameters(): param.requires_grad = False self.query_tokens.requires_grad = False self.quantize = VectorQuantizer2(n_embed, codebook_embed_dim, beta=0.25, remap=None, sane_index_shape=False) self.encode_task_layer = nn.Sequential( nn.Linear(self.Qformer.config.hidden_size, self.Qformer.config.hidden_size), nn.Tanh(), nn.Linear(self.Qformer.config.hidden_size, codebook_embed_dim) # for quantize ) self.decode_task_layer = nn.Sequential( nn.Linear(codebook_embed_dim, codebook_embed_dim), nn.Tanh(), nn.Linear(codebook_embed_dim, self.Qformer.config.hidden_size) # for quantize ) self.quantize = self.quantize.eval() self.quantize.training = False for name, param in self.named_parameters(): if 'quantize' in name or 'encode_task_layer' in name or 'decode_task_layer' in name: #print('freeze params', name) param.requires_grad = False if self.recon_s: self.pos_embed = nn.Parameter(torch.zeros(1, num_query_token, self.Qformer.config.hidden_size)) self.blocks = nn.ModuleList([ Block(dim=self.Qformer.config.hidden_size, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=partial(nn.LayerNorm, eps=1e-6)) for i in range(self.depth) ]) if self.blocks_for_image: self.pos_embed_image = nn.Parameter(torch.zeros(1, num_query_token, self.Qformer.config.hidden_size)) self.blocks_image = nn.ModuleList([ Block(dim=self.Qformer.config.hidden_size, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=partial(nn.LayerNorm, eps=1e-6)) for i in range(self.depth) ]) self.image_down = nn.Sequential( nn.Linear(self.Qformer.config.hidden_size, 256, bias=False), nn.ReLU(), nn.Linear(256, 128, bias=False), nn.ReLU(), nn.Linear(128, 32, bias=False), ) self.distill_image_proj = nn.Linear(num_query_token * 32, image_features_dim) def get_codebook_indices_only(self, visual_encoder, image): with torch.no_grad(): image_embeds = visual_encoder.ln_vision(visual_encoder(image)) image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) query_output = self.Qformer.bert( query_embeds=query_tokens, encoder_hidden_states=image_embeds, encoder_attention_mask=image_atts, return_dict=True, ) query_output_down = self.encode_task_layer(query_output.last_hidden_state) quant, loss_embed, embed_ind = self.quantize(query_output_down) embed_ind = embed_ind.reshape(quant.shape[0], -1) return embed_ind def get_codebook_indices(self, visual_encoder, image): with torch.no_grad(): image_embeds = visual_encoder.ln_vision(visual_encoder(image)) image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) query_output = self.Qformer.bert( query_embeds=query_tokens, encoder_hidden_states=image_embeds, encoder_attention_mask=image_atts, return_dict=True, ) query_output_down = self.encode_task_layer(query_output.last_hidden_state) quant, loss_embed, embed_ind = self.quantize(query_output_down) embed_ind = embed_ind.reshape(quant.shape[0], -1) query_output_up = self.decode_task_layer(quant) return embed_ind, query_output_up def get_codebook_entry(self, indices): with torch.no_grad(): quant_embedding = self.quantize.get_codebook_entry(indices) # print('quant_embedding_shape: ', quant_embedding.shape) # print(self.decode_task_layer) # exit() query_output_up = self.decode_task_layer(quant_embedding) pos_embed_image = self.pos_embed_image.repeat(query_output_up.shape[0], 1, 1) query_output_up_pos_image = query_output_up + pos_embed_image for blk in self.blocks_image: query_output_up_pos_image = blk(query_output_up_pos_image) query_output_up = query_output_up_pos_image reverse_output = self.image_down(query_output_up) reverse_output = reverse_output.reshape(reverse_output.shape[0], -1) reverse_output_proj = self.distill_image_proj(reverse_output) return reverse_output_proj @classmethod def get_vision_encoder(cls,model_name="eva_vit_g", img_size=224, drop_path_rate=0, use_grad_checkpoint=False, precision="fp32", cache_dir="./"): visual_encoder = create_eva_vit_g(img_size, drop_path_rate, use_grad_checkpoint, precision, cache_dir=cache_dir) visual_encoder.ln_vision = nn.LayerNorm(visual_encoder.num_features) for name, param in visual_encoder.named_parameters(): param.requires_grad = False visual_encoder = visual_encoder.eval() visual_encoder.ln_vision.weight.requires_grad = False visual_encoder.ln_vision.bias.requires_grad = False return visual_encoder class Seed2Tokenizer(PreTrainedModel): config_class = BertConfig base_model_prefix = "model" def __init__(self, config, image_size=224, drop_path_rate=0.4): super().__init__(config) model = Blip2QformerQuantizer(config) # .from_pretrained(pretrained_model_path=model_path, # cache_dir=cache_dir, # **kwargs).eval() #model = model.to(device) processor = transforms.Compose([ transforms.Resize((image_size, image_size), interpolation=3), # transforms.Resize(image_size, interpolation=3), # transforms.CenterCrop(image_size), transforms.ToTensor(), transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) ]) shape_latents = torch.Size([1, 4, 96, 96]) self.register_buffer("latents",torch.randn(shape_latents, generator=None, layout=torch.strided)) self.model = model self.processor = processor self.visual_encoder = VisionTransformerEvaVit( img_size=image_size, patch_size=14, use_mean_pooling=False, embed_dim=1408, depth=39, num_heads=1408 // 88, mlp_ratio=4.3637, qkv_bias=True, drop_path_rate=drop_path_rate, norm_layer=partial(nn.LayerNorm, eps=1e-6), use_checkpoint=False, ) def __len__(self): return self.model.n_embed def encode(self, image_torch, visual_encoder=None): '''Convert a batch of img to code Args: model: The tokenizer model. img: [b, c, h, w] ''' if visual_encoder is None: visual_encoder = self.visual_encoder if len(image_torch.shape) == 3: image_torch = image_torch.unsqueeze(0) image_torch = image_torch.to(dtype=self.latents.dtype) image_torch = image_torch.to(self.device) # img = image_torch.to(self.device) img = image_torch #if self.fp16: # img = img.half() #print (img.dtype) with torch.no_grad(): id = self.model.get_codebook_indices_only(visual_encoder, img) return id.view(img.shape[0], -1) def decode(self, diffusion_model, indices, guidance_scale=10, noise_level=0, num_inference_steps=20,): image_embeds = self.model.get_codebook_entry(indices) image_embeds = image_embeds.to(dtype=diffusion_model.dtype, device=diffusion_model.device) image = diffusion_model( image_embeds=image_embeds, guidance_scale=guidance_scale, noise_level=noise_level, num_inference_steps=num_inference_steps, latents=self.latents.to(dtype=diffusion_model.dtype, device=diffusion_model.device), ).images return image @property def num_image_tokens(self): return 8192 # self.image_tokenizer.num_tokens # allow not load def encode_image( self, image_path=None, image_pil=None, image_torch=None, image_size: int = 224, visual_encoder = None, ): assert (image_path is None) + (image_pil is None) + (image_torch is None) == 2 if visual_encoder is None: visual_encoder = self.visual_encoder # need_norm_to_1 = False if image_path is not None: image_pil = Image.open(image_path).convert('RGB') if image_pil is not None: image_torch = self.processor(image_pil) image_torch = image_torch.to(self.device) image_torch = image_torch.to(dtype=self.latents.dtype) return self.encode(image_torch, visual_encoder)