""" This script defines the MIPHEI-ViT architecture for image-to-image translation Some modules in this file are adapted from: https://github.com/hustvl/ViTMatte/ """ import os import torch import torch.nn as nn import torch.nn.functional as F import timm from timm.models import VisionTransformer, SwinTransformer from timm.models import load_state_dict_from_hf class Basic_Conv3x3(nn.Module): """ Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers. https://github.com/hustvl/ViTMatte/blob/main/modeling/decoder/detail_capture.py#L5 """ def __init__( self, in_chans, out_chans, stride=2, padding=1, ): super().__init__() self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False) self.bn = nn.BatchNorm2d(out_chans) self.relu = nn.ReLU(inplace=False) def forward(self, x): x = self.conv(x) x = self.bn(x) x = self.relu(x) return x class ConvStream(nn.Module): """ Simple ConvStream containing a series of basic conv3x3 layers to extract detail features. """ def __init__( self, in_chans = 4, out_chans = [48, 96, 192], ): super().__init__() self.convs = nn.ModuleList() self.conv_chans = out_chans.copy() self.conv_chans.insert(0, in_chans) for i in range(len(self.conv_chans)-1): in_chan_ = self.conv_chans[i] out_chan_ = self.conv_chans[i+1] self.convs.append( Basic_Conv3x3(in_chan_, out_chan_) ) def forward(self, x): out_dict = {'D0': x} for i in range(len(self.convs)): x = self.convs[i](x) name_ = 'D'+str(i+1) out_dict[name_] = x return out_dict class SegmentationHead(nn.Sequential): # https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/base/heads.py#L5 def __init__( self, in_channels, out_channels, kernel_size=3, activation=None, use_attention=False, ): if use_attention: attention = AttentionBlock(in_channels) else: attention = nn.Identity() conv2d = nn.Conv2d( in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2 ) activation = activation super().__init__(attention, conv2d, activation) class AttentionBlock(nn.Module): """ Attention gate Parameters: ----------- in_chns : int Number of input channels. Forward Input: -------------- x : torch.Tensor Input tensor of shape [B, C, H, W]. Returns: -------- torch.Tensor Reweighted tensor of the same shape as input. """ def __init__(self, in_chns): super(AttentionBlock, self).__init__() # Attention generation self.psi = nn.Sequential( nn.Conv2d(in_chns, in_chns // 2, kernel_size=1, stride=1, padding=0, bias=True), nn.BatchNorm2d(in_chns // 2), nn.ReLU(), nn.Conv2d(in_chns // 2, 1, kernel_size=1, stride=1, padding=0, bias=True), nn.Sigmoid() ) def forward(self, x): # Project decoder output to intermediate space g = self.psi(x) return x * g class Fusion_Block(nn.Module): """ Simple fusion block to fuse feature from ConvStream and Plain Vision Transformer. """ def __init__( self, in_chans, out_chans, ): super().__init__() self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1) def forward(self, x, D): F_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) ## Nearest ? out = torch.cat([D, F_up], dim=1) out = self.conv(out) return out class MIPHEIViT(nn.Module): """ U-Net-style architecture inspired by ViTMatte, using a Vision Transformer (ViT or Swin) as encoder and a convolutional decoder. Designed for dense image prediction tasks, such as image-to-image translation. Parameters: ----------- encoder : nn.Module A ViT- or Swin-based encoder that outputs spatial feature maps. 
class MIPHEIViT(nn.Module):
    """
    U-Net-style architecture inspired by ViTMatte, using a Vision Transformer (ViT or Swin)
    as encoder and a convolutional decoder.

    Designed for dense image prediction tasks, such as image-to-image translation.

    Parameters
    ----------
    encoder : nn.Module
        A ViT- or Swin-based encoder that outputs spatial feature maps.
    decoder : nn.Module
        A decoder module that maps encoder features (and optionally the original image)
        to the output prediction.

    Example
    -------
    encoder = Encoder(vit)
    model = MIPHEIViT(encoder=encoder, decoder=Detail_Capture(emb_chans=encoder.embed_dim))
    output = model(input_tensor)
    """
    def __init__(
        self,
        encoder,
        decoder,
    ):
        super(MIPHEIViT, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.initialize()

    def forward(self, x):
        features = self.encoder(x)
        outputs = self.decoder(features, x)
        return outputs

    def initialize(self):
        pass

    @classmethod
    def from_pretrained_hf(cls, repo_path=None, repo_id=None):
        from safetensors.torch import load_file
        import json

        if repo_path:
            weights_path = os.path.join(repo_path, "model.safetensors")
            config_path = os.path.join(repo_path, "config_hf.json")
        else:
            from huggingface_hub import hf_hub_download
            weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
            config_path = hf_hub_download(repo_id=repo_id, filename="config_hf.json")

        # Load config values
        with open(config_path, "r") as f:
            config = json.load(f)

        img_size = config["img_size"]
        nc_out = len(config["targ_channel_names"])
        use_attention = config["use_attention"]
        hoptimus_hf_id = config["hoptimus_hf_id"]

        # Rebuild the architecture described by the config, then load the checkpoint
        vit = get_hoptimus0_hf(hoptimus_hf_id)
        vit.set_input_size(img_size=(img_size, img_size))
        encoder = Encoder(vit)
        decoder = Detail_Capture(emb_chans=encoder.embed_dim, out_chans=nc_out,
                                 use_attention=use_attention, activation=nn.Tanh())
        model = cls(encoder=encoder, decoder=decoder)

        state_dict = load_file(weights_path)
        state_dict = merge_lora_weights(model, state_dict)
        load_info = model.load_state_dict(state_dict, strict=False)
        validate_load_info(load_info)
        model.eval()
        return model

    def set_input_size(self, img_size):
        if any((s & (s - 1)) != 0 or s == 0 for s in img_size):
            raise ValueError("Both height and width in img_size must be powers of 2")
        if any(s < 128 for s in img_size):
            raise ValueError("Height and width must be greater than or equal to 128")
        self.encoder.vit.set_input_size(img_size=img_size)
        self.encoder.grid_size = self.encoder.vit.patch_embed.grid_size
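
# Illustrative usage sketch for MIPHEIViT.from_pretrained_hf (defined as a function so it is
# not executed at import time). The repo id below is a placeholder, not a real repository:
# substitute the Hugging Face repo that hosts model.safetensors and config_hf.json.
def _example_pretrained_inference():
    model = MIPHEIViT.from_pretrained_hf(repo_id="<org>/<miphei-vit-checkpoint>")  # placeholder
    model.set_input_size((256, 256))       # optional: powers of 2, at least 128
    x = torch.randn(1, 3, 256, 256)        # H&E-style RGB tensor, preprocessed as in training
    with torch.no_grad():
        y = model(x)                       # (1, n_target_channels, 256, 256), Tanh-bounded
    return y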
class Encoder(nn.Module):
    """
    Wraps a Vision Transformer (ViT or Swin) to produce feature maps compatible with
    U-Net-like architectures. It reshapes and resizes transformer outputs into spatial
    feature maps.

    Parameters
    ----------
    vit : VisionTransformer or SwinTransformer
        A pretrained transformer model from `timm` that outputs patch embeddings.
    """
    def __init__(self, vit):
        super().__init__()
        if not isinstance(vit, (VisionTransformer, SwinTransformer)):
            raise ValueError(f"Expected a VisionTransformer or SwinTransformer, got {type(vit)}")
        self.vit = vit
        self.is_swint = isinstance(vit, SwinTransformer)
        self.grid_size = self.vit.patch_embed.grid_size
        if self.is_swint:
            self.num_prefix_tokens = 0
            self.embed_dim = self.vit.embed_dim * 2 ** (self.vit.num_layers - 1)
        else:
            self.num_prefix_tokens = self.vit.num_prefix_tokens
            self.embed_dim = self.vit.embed_dim
        patch_size = self.vit.patch_embed.patch_size
        img_size = self.vit.patch_embed.img_size
        assert img_size[0] % 16 == 0
        assert img_size[1] % 16 == 0
        if self.is_swint:
            self.scale_factor = (2., 2.)
        else:
            if patch_size != (16, 16):
                # Resize the patch grid so the decoder always receives a 1/16-resolution map
                target_grid_size = (img_size[0] / 16, img_size[1] / 16)
                self.scale_factor = (target_grid_size[0] / self.grid_size[0],
                                     target_grid_size[1] / self.grid_size[1])
            else:
                self.scale_factor = None

    def forward(self, x):
        features = self.vit(x)
        if self.is_swint:
            features = features.permute(0, 3, 1, 2)
        else:
            features = features[:, self.num_prefix_tokens:]
            features = features.permute(0, 2, 1)
            features = features.view((-1, self.embed_dim, *self.grid_size))
        if self.scale_factor is not None:
            features = F.interpolate(features, scale_factor=self.scale_factor, mode="bicubic")
        return features


class Detail_Capture(nn.Module):
    """
    Simple and lightweight detail-capture module for ViT matting.
    """
    def __init__(
        self,
        emb_chans,
        in_chans=3,
        out_chans=1,
        convstream_out=[48, 96, 192],
        fusion_out=[256, 128, 64, 32],
        use_attention=True,
        activation=torch.nn.Identity()
    ):
        super().__init__()
        assert len(fusion_out) == len(convstream_out) + 1

        self.convstream = ConvStream(in_chans=in_chans, out_chans=convstream_out)
        self.conv_chans = self.convstream.conv_chans
        self.num_heads = out_chans

        self.fusion_blks = nn.ModuleList()
        self.fus_channs = fusion_out.copy()
        self.fus_channs.insert(0, emb_chans)
        for i in range(len(self.fus_channs) - 1):
            self.fusion_blks.append(
                Fusion_Block(
                    in_chans=self.fus_channs[i] + self.conv_chans[-(i + 1)],
                    out_chans=self.fus_channs[i + 1],
                )
            )

        # One single-channel prediction head per output channel
        for idx in range(self.num_heads):
            setattr(self, f'segmentation_head_{idx}', SegmentationHead(
                in_channels=fusion_out[-1],
                out_channels=1,
                activation=activation,
                kernel_size=3,
                use_attention=use_attention
            ))

    def forward(self, features, images):
        detail_features = self.convstream(images)
        for i in range(len(self.fusion_blks)):
            d_name_ = 'D' + str(len(self.fusion_blks) - i - 1)
            features = self.fusion_blks[i](features, detail_features[d_name_])

        outputs = []
        for idx_head in range(self.num_heads):
            segmentation_head = getattr(self, f'segmentation_head_{idx_head}')
            output = segmentation_head(features)
            outputs.append(output)
        outputs = torch.cat(outputs, dim=1)
        return outputs
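
# Minimal composition sketch showing how Encoder, Detail_Capture and MIPHEIViT plug together
# without the H-optimus-0 backbone. The small timm ViT below is an arbitrary stand-in chosen
# to keep the example light; the published checkpoint uses get_hoptimus0_hf instead.
def _example_small_backbone(num_targets=2):
    vit = timm.create_model(
        "vit_small_patch16_224", pretrained=False,
        num_classes=0, global_pool="")                # token output, no pooling or head
    encoder = Encoder(vit)                            # produces (B, 384, 14, 14) feature maps
    decoder = Detail_Capture(
        emb_chans=encoder.embed_dim, out_chans=num_targets,
        use_attention=True, activation=nn.Tanh())
    model = MIPHEIViT(encoder=encoder, decoder=decoder)

    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        y = model(x)                                  # (1, num_targets, 224, 224)
    return y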
""" with torch.no_grad(): for idx in range(len(model.encoder.vit.blocks)): prefix = f"{block_prefix}.{idx}.attn.qkv" # Extract LoRA matrices A_q = state_dict.pop(f"{prefix}.lora_q.A") B_q = state_dict.pop(f"{prefix}.lora_q.B") A_v = state_dict.pop(f"{prefix}.lora_v.A") B_v = state_dict.pop(f"{prefix}.lora_v.B") # Compute low-rank updates (transposed to match weight shape) delta_q = (alpha * A_q @ B_q).T delta_v = (alpha * A_v @ B_v).T # Get original QKV weight matrix (shape: [3*dim, dim]) W = model.get_parameter(f"{prefix}.weight") dim = delta_q.shape[0] assert W.shape[0] == 3 * dim, f"Unexpected QKV shape: {W.shape}" # Apply LoRA deltas to Q and V projections W[:dim, :] += delta_q # Q projection W[2 * dim:, :] += delta_v # V projection return state_dict def get_hoptimus0_hf(repo_id): """ Hoptimus foundation model from hugginface repo id """ model = timm.create_model( "vit_giant_patch14_reg4_dinov2", img_size=224, drop_path_rate=0., num_classes=0, global_pool="", pretrained=False, init_values=1e-5, dynamic_img_size=False) state_dict = load_state_dict_from_hf(repo_id, weights_only=True) model.load_state_dict(state_dict) return model def validate_load_info(load_info): """ Validates the result of model.load_state_dict(..., strict=False). Raises: ValueError if unexpected keys are found, or if missing keys are not related to the allowed encoder modules. """ # 1. Raise if any unexpected keys if load_info.unexpected_keys: raise ValueError(f"Unexpected keys in state_dict: {load_info.unexpected_keys}") # 2. Raise if any missing keys are not part of allowed encoder modules for key in load_info.missing_keys: if ".lora" in key: raise ValueError(f"Missing LoRA checkpoint in state_dict: {key}") elif not any(part in key for part in ["encoder.vit.", "encoder.model."]): raise ValueError(f"Missing key in state_dict: {key}")