"""
This script defines the MIPHEI-ViT architecture for image-to-image translation.
Some modules in this file are adapted from: https://github.com/hustvl/ViTMatte/
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from timm.models import VisionTransformer, SwinTransformer
from timm.models import load_state_dict_from_hf
class Basic_Conv3x3(nn.Module):
"""
    Basic convolution block: a 3x3 Conv2d followed by BatchNorm2d and ReLU.
https://github.com/hustvl/ViTMatte/blob/main/modeling/decoder/detail_capture.py#L5
"""
def __init__(
self,
in_chans,
out_chans,
stride=2,
padding=1,
):
super().__init__()
self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False)
self.bn = nn.BatchNorm2d(out_chans)
self.relu = nn.ReLU(inplace=False)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class ConvStream(nn.Module):
"""
Simple ConvStream containing a series of basic conv3x3 layers to extract detail features.
"""
def __init__(
self,
in_chans = 4,
out_chans = [48, 96, 192],
):
super().__init__()
self.convs = nn.ModuleList()
self.conv_chans = out_chans.copy()
self.conv_chans.insert(0, in_chans)
for i in range(len(self.conv_chans)-1):
in_chan_ = self.conv_chans[i]
out_chan_ = self.conv_chans[i+1]
self.convs.append(
Basic_Conv3x3(in_chan_, out_chan_)
)
def forward(self, x):
out_dict = {'D0': x}
for i in range(len(self.convs)):
x = self.convs[i](x)
name_ = 'D'+str(i+1)
out_dict[name_] = x
return out_dict
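
# ConvStream usage sketch (an assumption for illustration, not part of the original
# pipeline): with the default out_chans=[48, 96, 192] and a 3-channel 256x256 input,
# each conv halves the spatial resolution, yielding a pyramid of detail features:
#   stream = ConvStream(in_chans=3)
#   feats = stream(torch.randn(1, 3, 256, 256))
#   # feats['D0']: [1, 3, 256, 256]   (input passthrough)
#   # feats['D1']: [1, 48, 128, 128]
#   # feats['D2']: [1, 96, 64, 64]
#   # feats['D3']: [1, 192, 32, 32]
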
class SegmentationHead(nn.Sequential):
# https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/base/heads.py#L5
def __init__(
self, in_channels, out_channels, kernel_size=3, activation=None, use_attention=False,
):
if use_attention:
attention = AttentionBlock(in_channels)
else:
attention = nn.Identity()
conv2d = nn.Conv2d(
in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
)
        activation = activation if activation is not None else nn.Identity()
super().__init__(attention, conv2d, activation)
class AttentionBlock(nn.Module):
"""
    Spatial attention gate that reweights the input features with a learned single-channel sigmoid map.
Parameters:
-----------
in_chns : int
Number of input channels.
Forward Input:
--------------
x : torch.Tensor
Input tensor of shape [B, C, H, W].
Returns:
--------
torch.Tensor
Reweighted tensor of the same shape as input.
"""
def __init__(self, in_chns):
super(AttentionBlock, self).__init__()
# Attention generation
self.psi = nn.Sequential(
nn.Conv2d(in_chns, in_chns // 2, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(in_chns // 2),
nn.ReLU(),
nn.Conv2d(in_chns // 2, 1, kernel_size=1, stride=1, padding=0, bias=True),
nn.Sigmoid()
)
def forward(self, x):
        # Compute a single-channel spatial attention map and reweight the input
g = self.psi(x)
return x * g
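
# AttentionBlock usage sketch (illustrative shapes only): the gate produces a
# single-channel map in [0, 1] and multiplies it with the input, so the output
# has the same shape as the input:
#   gate = AttentionBlock(in_chns=32)
#   y = gate(torch.randn(1, 32, 64, 64))  # y.shape == (1, 32, 64, 64)
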
class Fusion_Block(nn.Module):
"""
    Simple fusion block that merges detail features from the ConvStream with upsampled Vision Transformer features.
"""
def __init__(
self,
in_chans,
out_chans,
):
super().__init__()
self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1)
def forward(self, x, D):
        F_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)  # nearest-neighbor upsampling would be a possible alternative
out = torch.cat([D, F_up], dim=1)
out = self.conv(out)
return out
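
# Fusion_Block usage sketch (channel/spatial sizes are illustrative): the
# transformer feature map `x` is upsampled by 2x, concatenated with the matching
# detail feature `D`, and reduced by a 3x3 conv:
#   fuse = Fusion_Block(in_chans=1536 + 192, out_chans=256)
#   out = fuse(torch.randn(1, 1536, 16, 16), torch.randn(1, 192, 32, 32))
#   # out.shape == (1, 256, 32, 32)
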
class MIPHEIViT(nn.Module):
"""
U-Net-style architecture inspired by ViTMatte, using a Vision Transformer (ViT or Swin)
as encoder and a convolutional decoder. Designed for dense image prediction tasks,
such as image-to-image translation.
Parameters:
-----------
encoder : nn.Module
A ViT- or Swin-based encoder that outputs spatial feature maps.
decoder : nn.Module
A decoder module that maps encoder features (and optionally the original image)
to the output prediction.
Example:
--------
    encoder = Encoder(vit)
    model = MIPHEIViT(encoder=encoder, decoder=Detail_Capture(emb_chans=encoder.embed_dim))
output = model(input_tensor)
"""
def __init__(self,
encoder,
decoder,
):
super(MIPHEIViT, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.initialize()
def forward(self, x):
features = self.encoder(x)
outputs = self.decoder(features, x)
return outputs
def initialize(self):
pass
@classmethod
def from_pretrained_hf(cls, repo_path=None, repo_id=None):
from safetensors.torch import load_file
import json
if repo_path:
weights_path = os.path.join(repo_path, "model.safetensors")
config_path = os.path.join(repo_path, "config_hf.json")
else:
from huggingface_hub import hf_hub_download
weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
config_path = hf_hub_download(repo_id=repo_id, filename="config_hf.json")
# Load config values
with open(config_path, "r") as f:
config = json.load(f)
img_size = config["img_size"]
nc_out = len(config["targ_channel_names"])
use_attention = config["use_attention"]
hoptimus_hf_id = config["hoptimus_hf_id"]
vit = get_hoptimus0_hf(hoptimus_hf_id)
vit.set_input_size(img_size=(img_size, img_size))
encoder = Encoder(vit)
decoder = Detail_Capture(emb_chans=encoder.embed_dim, out_chans=nc_out, use_attention=use_attention, activation=nn.Tanh())
model = cls(encoder=encoder, decoder=decoder)
state_dict = load_file(weights_path)
state_dict = merge_lora_weights(model, state_dict)
load_info = model.load_state_dict(state_dict, strict=False)
validate_load_info(load_info)
model.eval()
return model
def set_input_size(self, img_size):
if any((s & (s - 1)) != 0 or s == 0 for s in img_size):
raise ValueError("Both height and width in img_size must be powers of 2")
if any(s < 128 for s in img_size):
raise ValueError("Height and width must be greater or equal to 128")
self.encoder.vit.set_input_size(img_size=img_size)
self.encoder.grid_size = self.encoder.vit.patch_embed.grid_size
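
# MIPHEIViT loading sketch (the repo id below is a placeholder, not a verified
# value; from_pretrained_hf downloads the config and the merged/LoRA weights):
#   model = MIPHEIViT.from_pretrained_hf(repo_id="<user>/<miphei-vit-repo>")
#   model.set_input_size((512, 512))  # optional; sides must be powers of 2, >= 128
#   with torch.inference_mode():
#       preds = model(torch.randn(1, 3, 512, 512))
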
class Encoder(nn.Module):
"""
Wraps a Vision Transformer (ViT or Swin) to produce feature maps compatible
with U-Net-like architectures. It reshapes and resizes transformer outputs
into spatial feature maps.
Parameters:
-----------
vit : VisionTransformer or SwinTransformer
A pretrained transformer model from `timm` that outputs patch embeddings.
"""
def __init__(self, vit):
super().__init__()
if not isinstance(vit, (VisionTransformer, SwinTransformer)):
raise ValueError(f"Expected a VisionTransformer or SwinTransformer, got {type(vit)}")
self.vit = vit
self.is_swint = isinstance(vit, SwinTransformer)
self.grid_size = self.vit.patch_embed.grid_size
if self.is_swint:
self.num_prefix_tokens = 0
            self.embed_dim = self.vit.embed_dim * 2 ** (self.vit.num_layers - 1)
else:
self.num_prefix_tokens = self.vit.num_prefix_tokens
self.embed_dim = self.vit.embed_dim
patch_size = self.vit.patch_embed.patch_size
img_size = self.vit.patch_embed.img_size
assert img_size[0] % 16 == 0
assert img_size[1] % 16 == 0
if self.is_swint:
self.scale_factor = (2., 2.)
else:
if patch_size != (16, 16):
target_grid_size = (img_size[0] / 16, img_size[1] / 16)
self.scale_factor = (target_grid_size[0] / self.grid_size[0], target_grid_size[1] / self.grid_size[1])
else:
self.scale_factor = None
def forward(self, x):
features = self.vit(x)
if self.is_swint:
features = features.permute(0, 3, 1, 2)
else:
features = features[:, self.num_prefix_tokens:]
features = features.permute(0, 2, 1)
features = features.view((-1, self.embed_dim, *self.grid_size))
if self.scale_factor is not None:
features = F.interpolate(features, scale_factor=self.scale_factor, mode="bicubic")
return features
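
# Encoder shape sketch (assuming a plain ViT backbone): patch tokens are reshaped
# into a spatial map and, for patch sizes other than 16, resized so the output
# grid is img_size / 16:
#   vit = timm.create_model("vit_small_patch16_224", pretrained=False,
#                           num_classes=0, global_pool="")
#   enc = Encoder(vit)
#   fmap = enc(torch.randn(1, 3, 224, 224))  # fmap.shape == (1, 384, 14, 14)
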
class Detail_Capture(nn.Module):
"""
    Simple and lightweight detail-capture decoder adapted from ViTMatte, extended here with one segmentation head per output channel.
"""
def __init__(
self,
emb_chans,
in_chans=3,
out_chans=1,
convstream_out = [48, 96, 192],
fusion_out = [256, 128, 64, 32],
use_attention=True,
activation=torch.nn.Identity()
):
super().__init__()
assert len(fusion_out) == len(convstream_out) + 1
self.convstream = ConvStream(in_chans=in_chans)
self.conv_chans = self.convstream.conv_chans
self.num_heads = out_chans
self.fusion_blks = nn.ModuleList()
self.fus_channs = fusion_out.copy()
self.fus_channs.insert(0, emb_chans)
for i in range(len(self.fus_channs)-1):
self.fusion_blks.append(
Fusion_Block(
in_chans = self.fus_channs[i] + self.conv_chans[-(i+1)],
out_chans = self.fus_channs[i+1],
)
)
for idx in range(self.num_heads):
setattr(self, f'segmentation_head_{idx}', SegmentationHead(
in_channels=fusion_out[-1],
out_channels=1,
activation=activation,
kernel_size=3,
use_attention=use_attention
))
def forward(self, features, images):
detail_features = self.convstream(images)
for i in range(len(self.fusion_blks)):
d_name_ = 'D'+str(len(self.fusion_blks)-i-1)
features = self.fusion_blks[i](features, detail_features[d_name_])
outputs = []
for idx_head in range(self.num_heads):
segmentation_head = getattr(self, f'segmentation_head_{idx_head}')
output = segmentation_head(features)
outputs.append(output)
outputs = torch.cat(outputs, dim=1)
return outputs
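
# Detail_Capture usage sketch (illustrative values): with a 1536-dim encoder
# feature map and out_chans=16, the decoder fuses four detail levels and applies
# one segmentation head per output channel:
#   dec = Detail_Capture(emb_chans=1536, out_chans=16, activation=nn.Tanh())
#   y = dec(torch.randn(1, 1536, 16, 16), torch.randn(1, 3, 256, 256))
#   # y.shape == (1, 16, 256, 256)
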
def merge_lora_weights(model, state_dict, alpha=1.0, block_prefix="encoder.vit.blocks"):
"""
Merges LoRA weights into the base attention Q and V projection weights for each transformer block.
    LoRA weights are stored in model.safetensors so the repository does not need to ship the original foundation-model weights.
Parameters:
-----------
model : torch.nn.Module
The model containing the transformer blocks to modify (e.g., ViT backbone).
state_dict : dict
The state_dict containing LoRA matrices with keys formatted as
'{block_prefix}.{idx}.attn.qkv.lora_q.A', etc.
This dict is modified in-place to remove LoRA weights after merging.
alpha : float, optional
Scaling factor for the LoRA update. Defaults to 1.0.
block_prefix : str, optional
Prefix to locate transformer blocks in the model. Defaults to "encoder.vit.blocks".
Returns:
--------
dict
The modified state_dict with LoRA weights removed after merging.
"""
with torch.no_grad():
for idx in range(len(model.encoder.vit.blocks)):
prefix = f"{block_prefix}.{idx}.attn.qkv"
# Extract LoRA matrices
A_q = state_dict.pop(f"{prefix}.lora_q.A")
B_q = state_dict.pop(f"{prefix}.lora_q.B")
A_v = state_dict.pop(f"{prefix}.lora_v.A")
B_v = state_dict.pop(f"{prefix}.lora_v.B")
# Compute low-rank updates (transposed to match weight shape)
delta_q = (alpha * A_q @ B_q).T
delta_v = (alpha * A_v @ B_v).T
# Get original QKV weight matrix (shape: [3*dim, dim])
W = model.get_parameter(f"{prefix}.weight")
dim = delta_q.shape[0]
assert W.shape[0] == 3 * dim, f"Unexpected QKV shape: {W.shape}"
# Apply LoRA deltas to Q and V projections
W[:dim, :] += delta_q # Q projection
W[2 * dim:, :] += delta_v # V projection
return state_dict
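
# Merge sketch: each fused QKV weight has shape [3*dim, dim]; the low-rank
# products update the query rows (W[:dim]) and value rows (W[2*dim:]) in place,
# so the exported checkpoint needs no LoRA modules at inference time, e.g.
#   state_dict = merge_lora_weights(model, load_file("model.safetensors"))
#   model.load_state_dict(state_dict, strict=False)
# (load_file is safetensors.torch.load_file, as used in from_pretrained_hf).
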
def get_hoptimus0_hf(repo_id):
""" Hoptimus foundation model from hugginface repo id
"""
model = timm.create_model(
"vit_giant_patch14_reg4_dinov2", img_size=224,
drop_path_rate=0., num_classes=0,
global_pool="", pretrained=False, init_values=1e-5,
dynamic_img_size=False)
state_dict = load_state_dict_from_hf(repo_id, weights_only=True)
model.load_state_dict(state_dict)
return model
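
# Note: the checkpoint fetched from `repo_id` must match timm's
# "vit_giant_patch14_reg4_dinov2" configuration created above (ViT-g/14,
# 1536-dim embeddings, 4 register tokens); otherwise load_state_dict will fail.
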
def validate_load_info(load_info):
"""
Validates the result of model.load_state_dict(..., strict=False).
Raises:
ValueError if unexpected keys are found,
or if missing keys are not related to the allowed encoder modules.
"""
# 1. Raise if any unexpected keys
if load_info.unexpected_keys:
raise ValueError(f"Unexpected keys in state_dict: {load_info.unexpected_keys}")
# 2. Raise if any missing keys are not part of allowed encoder modules
for key in load_info.missing_keys:
if ".lora" in key:
raise ValueError(f"Missing LoRA checkpoint in state_dict: {key}")
elif not any(part in key for part in ["encoder.vit.", "encoder.model."]):
raise ValueError(f"Missing key in state_dict: {key}")