""" | |
This script defines the MIPHEI-ViT architecture for image-to-image translation | |
Some modules in this file are adapted from: https://github.com/hustvl/ViTMatte/ | |
""" | |
import os | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import timm | |
from timm.models import VisionTransformer, SwinTransformer | |
from timm.models import load_state_dict_from_hf | |


class Basic_Conv3x3(nn.Module):
    """
    Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers.
    https://github.com/hustvl/ViTMatte/blob/main/modeling/decoder/detail_capture.py#L5
    """
    def __init__(
        self,
        in_chans,
        out_chans,
        stride=2,
        padding=1,
    ):
        super().__init__()
        self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_chans)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class ConvStream(nn.Module):
    """
    Simple ConvStream containing a series of basic conv3x3 layers to extract detail features.
    """
    def __init__(
        self,
        in_chans=4,
        out_chans=[48, 96, 192],
    ):
        super().__init__()
        self.convs = nn.ModuleList()
        self.conv_chans = out_chans.copy()
        self.conv_chans.insert(0, in_chans)
        for i in range(len(self.conv_chans) - 1):
            in_chan_ = self.conv_chans[i]
            out_chan_ = self.conv_chans[i + 1]
            self.convs.append(
                Basic_Conv3x3(in_chan_, out_chan_)
            )

    def forward(self, x):
        out_dict = {'D0': x}
        for i in range(len(self.convs)):
            x = self.convs[i](x)
            name_ = 'D' + str(i + 1)
            out_dict[name_] = x
        return out_dict
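

# Illustrative sketch (not part of the original model code): each Basic_Conv3x3 in
# ConvStream uses stride 2, so the detail pyramid halves the spatial resolution at
# every level while increasing the channel count.
def _demo_convstream_shapes():
    convstream = ConvStream(in_chans=3)
    x = torch.randn(1, 3, 256, 256)
    feats = convstream(x)
    # Expected shapes: D0 [1, 3, 256, 256], D1 [1, 48, 128, 128],
    #                  D2 [1, 96, 64, 64],  D3 [1, 192, 32, 32]
    return {name: tuple(f.shape) for name, f in feats.items()}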


class SegmentationHead(nn.Sequential):
    # https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/base/heads.py#L5
    def __init__(
        self, in_channels, out_channels, kernel_size=3, activation=None, use_attention=False,
    ):
        if use_attention:
            attention = AttentionBlock(in_channels)
        else:
            attention = nn.Identity()
        conv2d = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
        )
        # Fall back to an identity activation so the default (None) still builds a valid Sequential
        activation = activation if activation is not None else nn.Identity()
        super().__init__(attention, conv2d, activation)
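

# Illustrative sketch (assumption: a single head over the final 32-channel decoder
# features, as built by Detail_Capture below): SegmentationHead optionally gates its
# input with AttentionBlock, then maps it to one output channel.
def _demo_segmentation_head():
    head = SegmentationHead(in_channels=32, out_channels=1,
                            activation=nn.Tanh(), use_attention=True)
    x = torch.randn(1, 32, 256, 256)
    return head(x)  # -> [1, 1, 256, 256]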


class AttentionBlock(nn.Module):
    """
    Attention gate.

    Parameters:
    -----------
    in_chns : int
        Number of input channels.

    Forward Input:
    --------------
    x : torch.Tensor
        Input tensor of shape [B, C, H, W].

    Returns:
    --------
    torch.Tensor
        Reweighted tensor of the same shape as the input.
    """
    def __init__(self, in_chns):
        super(AttentionBlock, self).__init__()
        # Attention generation
        self.psi = nn.Sequential(
            nn.Conv2d(in_chns, in_chns // 2, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(in_chns // 2),
            nn.ReLU(),
            nn.Conv2d(in_chns // 2, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Compute a single-channel attention map in [0, 1] and reweight the input
        g = self.psi(x)
        return x * g
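

# Illustrative sketch (not part of the original model code): the attention gate
# produces a single-channel map in [0, 1] and rescales the input elementwise,
# so the output has the same shape as the input.
def _demo_attention_block():
    block = AttentionBlock(in_chns=32)
    x = torch.randn(2, 32, 64, 64)
    y = block(x)
    assert y.shape == x.shape
    return y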


class Fusion_Block(nn.Module):
    """
    Simple fusion block to fuse features from ConvStream and a plain Vision Transformer.
    """
    def __init__(
        self,
        in_chans,
        out_chans,
    ):
        super().__init__()
        self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1)

    def forward(self, x, D):
        # Bilinear upsampling; nearest-neighbor would be a possible alternative
        F_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        out = torch.cat([D, F_up], dim=1)
        out = self.conv(out)
        return out
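

# Illustrative sketch (shapes are assumptions): Fusion_Block upsamples the coarse
# feature map by 2x, concatenates it with a detail map of matching resolution, and
# mixes the two with a stride-1 Basic_Conv3x3.
def _demo_fusion_block():
    fuse = Fusion_Block(in_chans=192 + 48, out_chans=64)
    coarse = torch.randn(1, 192, 16, 16)  # e.g. an upstream decoder feature map
    detail = torch.randn(1, 48, 32, 32)   # ConvStream feature at 2x the resolution
    return fuse(coarse, detail)           # -> [1, 64, 32, 32]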


class MIPHEIViT(nn.Module):
    """
    U-Net-style architecture inspired by ViTMatte, using a Vision Transformer (ViT or Swin)
    as encoder and a convolutional decoder. Designed for dense image prediction tasks,
    such as image-to-image translation.

    Parameters:
    -----------
    encoder : nn.Module
        A ViT- or Swin-based encoder that outputs spatial feature maps.
    decoder : nn.Module
        A decoder module that maps encoder features (and optionally the original image)
        to the output prediction.

    Example:
    --------
    model = MIPHEIViT(encoder=Encoder(vit), decoder=UNetDecoder())
    output = model(input_tensor)
    """
    def __init__(self,
                 encoder,
                 decoder,
                 ):
        super(MIPHEIViT, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.initialize()

    def forward(self, x):
        features = self.encoder(x)
        outputs = self.decoder(features, x)
        return outputs

    def initialize(self):
        pass

    @classmethod
    def from_pretrained_hf(cls, repo_path=None, repo_id=None):
        from safetensors.torch import load_file
        import json

        if repo_path:
            weights_path = os.path.join(repo_path, "model.safetensors")
            config_path = os.path.join(repo_path, "config_hf.json")
        else:
            from huggingface_hub import hf_hub_download
            weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
            config_path = hf_hub_download(repo_id=repo_id, filename="config_hf.json")

        # Load config values
        with open(config_path, "r") as f:
            config = json.load(f)

        img_size = config["img_size"]
        nc_out = len(config["targ_channel_names"])
        use_attention = config["use_attention"]
        hoptimus_hf_id = config["hoptimus_hf_id"]

        vit = get_hoptimus0_hf(hoptimus_hf_id)
        vit.set_input_size(img_size=(img_size, img_size))
        encoder = Encoder(vit)
        decoder = Detail_Capture(emb_chans=encoder.embed_dim, out_chans=nc_out,
                                 use_attention=use_attention, activation=nn.Tanh())
        model = cls(encoder=encoder, decoder=decoder)

        state_dict = load_file(weights_path)
        state_dict = merge_lora_weights(model, state_dict)
        load_info = model.load_state_dict(state_dict, strict=False)
        validate_load_info(load_info)
        model.eval()
        return model

    def set_input_size(self, img_size):
        if any((s & (s - 1)) != 0 or s == 0 for s in img_size):
            raise ValueError("Both height and width in img_size must be powers of 2")
        if any(s < 128 for s in img_size):
            raise ValueError("Height and width must be greater than or equal to 128")
        self.encoder.vit.set_input_size(img_size=img_size)
        self.encoder.grid_size = self.encoder.vit.patch_embed.grid_size
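

# Usage sketch (not part of the original module; "example/miphei-vit" is a
# hypothetical repo id, and tile_size is assumed to match the checkpoint's
# configured img_size).
def _demo_load_and_predict(tile_size=1024):
    model = MIPHEIViT.from_pretrained_hf(repo_id="example/miphei-vit")  # hypothetical repo id
    x = torch.randn(1, 3, tile_size, tile_size)  # RGB H&E tile in the range the model expects
    with torch.no_grad():
        y = model(x)  # -> [1, n_target_channels, tile_size, tile_size]
    return y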


class Encoder(nn.Module):
    """
    Wraps a Vision Transformer (ViT or Swin) to produce feature maps compatible
    with U-Net-like architectures. It reshapes and resizes transformer outputs
    into spatial feature maps.

    Parameters:
    -----------
    vit : VisionTransformer or SwinTransformer
        A pretrained transformer model from `timm` that outputs patch embeddings.
    """
    def __init__(self, vit):
        super().__init__()
        if not isinstance(vit, (VisionTransformer, SwinTransformer)):
            raise ValueError(f"Expected a VisionTransformer or SwinTransformer, got {type(vit)}")
        self.vit = vit
        self.is_swint = isinstance(vit, SwinTransformer)
        self.grid_size = self.vit.patch_embed.grid_size
        if self.is_swint:
            self.num_prefix_tokens = 0
            self.embed_dim = self.vit.embed_dim * 2 ** (self.vit.num_layers - 1)
        else:
            self.num_prefix_tokens = self.vit.num_prefix_tokens
            self.embed_dim = self.vit.embed_dim

        patch_size = self.vit.patch_embed.patch_size
        img_size = self.vit.patch_embed.img_size
        assert img_size[0] % 16 == 0
        assert img_size[1] % 16 == 0

        if self.is_swint:
            self.scale_factor = (2., 2.)
        else:
            if patch_size != (16, 16):
                # Resample non-16px patch grids so the decoder always receives a 1/16-resolution map
                target_grid_size = (img_size[0] / 16, img_size[1] / 16)
                self.scale_factor = (target_grid_size[0] / self.grid_size[0], target_grid_size[1] / self.grid_size[1])
            else:
                self.scale_factor = None

    def forward(self, x):
        features = self.vit(x)
        if self.is_swint:
            features = features.permute(0, 3, 1, 2)
        else:
            features = features[:, self.num_prefix_tokens:]
            features = features.permute(0, 2, 1)
            features = features.view((-1, self.embed_dim, *self.grid_size))
        if self.scale_factor is not None:
            features = F.interpolate(features, scale_factor=self.scale_factor, mode="bicubic")
        return features
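

# Illustrative sketch (uses a small, randomly initialized timm ViT purely to show
# shapes): the Encoder drops prefix tokens, reshapes [B, N, C] patch tokens into a
# [B, C, H/16, W/16] map, and resamples when the patch size is not 16.
def _demo_encoder_features():
    vit = timm.create_model(
        "vit_small_patch16_224", pretrained=False, num_classes=0, global_pool="")
    encoder = Encoder(vit)
    x = torch.randn(1, 3, 224, 224)
    return encoder(x)  # -> [1, 384, 14, 14]: one feature vector per 16x16 patch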


class Detail_Capture(nn.Module):
    """
    Simple and Lightweight Detail Capture Module for ViT Matting.
    """
    def __init__(
        self,
        emb_chans,
        in_chans=3,
        out_chans=1,
        convstream_out=[48, 96, 192],
        fusion_out=[256, 128, 64, 32],
        use_attention=True,
        activation=torch.nn.Identity()
    ):
        super().__init__()
        assert len(fusion_out) == len(convstream_out) + 1

        self.convstream = ConvStream(in_chans=in_chans)
        self.conv_chans = self.convstream.conv_chans
        self.num_heads = out_chans

        self.fusion_blks = nn.ModuleList()
        self.fus_channs = fusion_out.copy()
        self.fus_channs.insert(0, emb_chans)
        for i in range(len(self.fus_channs) - 1):
            self.fusion_blks.append(
                Fusion_Block(
                    in_chans=self.fus_channs[i] + self.conv_chans[-(i + 1)],
                    out_chans=self.fus_channs[i + 1],
                )
            )

        for idx in range(self.num_heads):
            setattr(self, f'segmentation_head_{idx}', SegmentationHead(
                in_channels=fusion_out[-1],
                out_channels=1,
                activation=activation,
                kernel_size=3,
                use_attention=use_attention
            ))

    def forward(self, features, images):
        detail_features = self.convstream(images)
        for i in range(len(self.fusion_blks)):
            d_name_ = 'D' + str(len(self.fusion_blks) - i - 1)
            features = self.fusion_blks[i](features, detail_features[d_name_])

        outputs = []
        for idx_head in range(self.num_heads):
            segmentation_head = getattr(self, f'segmentation_head_{idx_head}')
            output = segmentation_head(features)
            outputs.append(output)
        outputs = torch.cat(outputs, dim=1)
        return outputs
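

# Illustrative sketch (emb_chans=384 and the tile size are assumptions): the decoder
# fuses one coarse 1/16-resolution feature map with the ConvStream detail pyramid,
# doubling the resolution at each Fusion_Block, then runs one SegmentationHead per
# output channel and concatenates the results.
def _demo_detail_capture():
    decoder = Detail_Capture(emb_chans=384, in_chans=3, out_chans=4)
    images = torch.randn(1, 3, 256, 256)
    features = torch.randn(1, 384, 16, 16)  # e.g. Encoder output at 1/16 resolution
    return decoder(features, images)        # -> [1, 4, 256, 256]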


def merge_lora_weights(model, state_dict, alpha=1.0, block_prefix="encoder.vit.blocks"):
    """
    Merges LoRA weights into the base attention Q and V projection weights for each transformer block.
    We keep LoRA weights in model.safetensors to avoid having the original foundation model weights in the repo.

    Parameters:
    -----------
    model : torch.nn.Module
        The model containing the transformer blocks to modify (e.g., ViT backbone).
    state_dict : dict
        The state_dict containing LoRA matrices with keys formatted as
        '{block_prefix}.{idx}.attn.qkv.lora_q.A', etc.
        This dict is modified in-place to remove LoRA weights after merging.
    alpha : float, optional
        Scaling factor for the LoRA update. Defaults to 1.0.
    block_prefix : str, optional
        Prefix to locate transformer blocks in the model. Defaults to "encoder.vit.blocks".

    Returns:
    --------
    dict
        The modified state_dict with LoRA weights removed after merging.
    """
    with torch.no_grad():
        for idx in range(len(model.encoder.vit.blocks)):
            prefix = f"{block_prefix}.{idx}.attn.qkv"

            # Extract LoRA matrices
            A_q = state_dict.pop(f"{prefix}.lora_q.A")
            B_q = state_dict.pop(f"{prefix}.lora_q.B")
            A_v = state_dict.pop(f"{prefix}.lora_v.A")
            B_v = state_dict.pop(f"{prefix}.lora_v.B")

            # Compute low-rank updates (transposed to match weight shape)
            delta_q = (alpha * A_q @ B_q).T
            delta_v = (alpha * A_v @ B_v).T

            # Get original QKV weight matrix (shape: [3*dim, dim])
            W = model.get_parameter(f"{prefix}.weight")
            dim = delta_q.shape[0]
            assert W.shape[0] == 3 * dim, f"Unexpected QKV shape: {W.shape}"

            # Apply LoRA deltas to Q and V projections
            W[:dim, :] += delta_q  # Q projection
            W[2 * dim:, :] += delta_v  # V projection

    return state_dict
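

# Shape sketch (an assumption about the LoRA layout, consistent with the merge above):
# for a rank-r adapter on a projection of width dim, A is [dim, r] and B is [r, dim],
# so (alpha * A @ B).T is a [dim, dim] update added to the Q or V slice of the fused
# [3*dim, dim] qkv weight.
def _demo_lora_delta(dim=64, r=4, alpha=1.0):
    A = torch.randn(dim, r)
    B = torch.randn(r, dim)
    delta = (alpha * A @ B).T        # [dim, dim], same layout as W[:dim, :]
    W_qkv = torch.zeros(3 * dim, dim)
    W_qkv[:dim, :] += delta          # Q slice; the V slice would be W_qkv[2*dim:, :]
    return W_qkv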


def get_hoptimus0_hf(repo_id):
    """Load the H-optimus-0 foundation model from a Hugging Face repo id."""
    model = timm.create_model(
        "vit_giant_patch14_reg4_dinov2", img_size=224,
        drop_path_rate=0., num_classes=0,
        global_pool="", pretrained=False, init_values=1e-5,
        dynamic_img_size=False)
    state_dict = load_state_dict_from_hf(repo_id, weights_only=True)
    model.load_state_dict(state_dict)
    return model


def validate_load_info(load_info):
    """
    Validates the result of model.load_state_dict(..., strict=False).

    Raises:
        ValueError if unexpected keys are found,
        or if missing keys are not related to the allowed encoder modules.
    """
    # 1. Raise if any unexpected keys
    if load_info.unexpected_keys:
        raise ValueError(f"Unexpected keys in state_dict: {load_info.unexpected_keys}")

    # 2. Raise if any missing keys are not part of allowed encoder modules
    for key in load_info.missing_keys:
        if ".lora" in key:
            raise ValueError(f"Missing LoRA checkpoint in state_dict: {key}")
        elif not any(part in key for part in ["encoder.vit.", "encoder.model."]):
            raise ValueError(f"Missing key in state_dict: {key}")