U1020040 committed
Commit · 3047e70 · 1 Parent(s): fbbdc88
first commit

Files changed:
- .gitignore +3 -0
- LICENSE +9 -0
- app.py +59 -0
- logo.svg +232 -0
- model.py +410 -0
- requirements.txt +7 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
+__pycache__/
+*.py[cod]
+*.so
LICENSE
ADDED
@@ -0,0 +1,9 @@
+Non Commercial License Notice:
+
+Copyright (c) 2025 Sanofi.
+
+Permission is hereby granted, free of charge, for academic research purposes only and for non-commercial uses only, to any person from academic research or non-profit organizations obtaining a copy of this software and associated documentation files (the "Software"), to use, copy, modify, or merge the Software, subject to the following conditions: this permission notice shall be included in all copies of the Software or of substantial portions of the Software.
+
+For purposes of this license, “non-commercial use” excludes uses foreseeably resulting in a commercial benefit. To use this software for other purposes (such as the development of a commercial product, including but not limited to software, service, or pharmaceuticals, or in a collaboration with a private company), please contact SANOFI at [email protected].
+
+All other rights are reserved, including those for text and data mining, AI training and similar technologies. The Software is provided “as is”, without warranty of any kind, express or implied, including the warranties of noninfringement.
app.py
ADDED
@@ -0,0 +1,59 @@
+import gradio as gr
+from huggingface_hub import hf_hub_download
+import torch
+import json
+from PIL import Image
+import numpy as np
+from model import MIPHEIViT
+
+# Load model once
+repo_id = "Estabousi/MIPHEI-vit"
+model = MIPHEIViT.from_pretrained_hf(repo_id=repo_id)
+config_path = hf_hub_download(repo_id=repo_id, filename="config_hf.json")
+
+model.eval()
+mean = torch.Tensor([0.485, 0.456, 0.406]).to(torch.float32).reshape((-1, 1, 1))
+std = torch.Tensor([0.229, 0.224, 0.225]).to(torch.float32).reshape((-1, 1, 1))
+with open(config_path, "r") as f:
+    config = json.load(f)
+channel_names = config["targ_channel_names"]
+
+def preprocess(image):
+    image = image.convert("RGB").resize((256, 256))
+    tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255
+    tensor = (tensor - mean) / std
+    return tensor.unsqueeze(0)  # [1, 3, H, W]
+
+def predict(image):
+    input_tensor = preprocess(image)
+    with torch.inference_mode():
+        output = model(input_tensor)[0]  # [16, H, W]
+        output = (output.clamp(-0.9, 0.9) + 0.9) / 1.8
+        output = np.uint8(output.cpu().numpy() * 255)
+
+    # Convert each mIF channel to grayscale PIL image
+    channel_imgs = []
+    for i in range(output.shape[0]):
+        ch_img = output[i]
+        pil_ch = Image.fromarray(ch_img, mode='L')
+        channel_imgs.append(pil_ch)
+
+    # Return the list of 16 predicted channels
+    return channel_imgs
+
+# Prepare Gradio UI
+demo = gr.Interface(
+    fn=predict,
+    inputs=gr.Image(type="pil", label="Input H&E"),
+    outputs=[gr.Image(type="pil", label=f"mIF Channel {channel_names[i]}") for i in range(16)],
+    title="MIPHEI-ViT: Full mIF Prediction",
+    description=(
+        "Upload an H&E image tile (colorectal, 256×256 px at 0.5 µm/px recommended). "
+        "The image will be resized to 256×256 if needed.\n"
+        "The model predicts 16-channel multiplex immunofluorescence, "
+        "with each marker shown as a grayscale image."
+    )
+)
+
+if __name__ == "__main__":
+    demo.launch()
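For readers who want to exercise the committed demo outside of Gradio, a minimal local smoke test could look like the sketch below. It is not part of the commit: the tile path is a placeholder, and importing app triggers the checkpoint download from the Hub.

# Hypothetical smoke test for app.py; "he_tile.png" is a placeholder path.
from PIL import Image

from app import channel_names, predict  # importing app loads the model from the Hub

tile = Image.open("he_tile.png")         # any H&E tile; predict() resizes it to 256x256
channels = predict(tile)                 # list of 16 grayscale PIL images
for name, img in zip(channel_names, channels):
    img.save(f"pred_{name}.png")         # one PNG per predicted mIF marker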
logo.svg
ADDED
model.py
ADDED
@@ -0,0 +1,410 @@
+"""
+This script defines the MIPHEI-ViT architecture for image-to-image translation.
+Some modules in this file are adapted from: https://github.com/hustvl/ViTMatte/
+"""
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import timm
+from timm.models import VisionTransformer, SwinTransformer
+from timm.models import load_state_dict_from_hf
+
+
+class Basic_Conv3x3(nn.Module):
+    """
+    Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers.
+    https://github.com/hustvl/ViTMatte/blob/main/modeling/decoder/detail_capture.py#L5
+    """
+    def __init__(
+        self,
+        in_chans,
+        out_chans,
+        stride=2,
+        padding=1,
+    ):
+        super().__init__()
+        self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False)
+        self.bn = nn.BatchNorm2d(out_chans)
+        self.relu = nn.ReLU(inplace=False)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.relu(x)
+
+        return x
+
+
+class ConvStream(nn.Module):
+    """
+    Simple ConvStream containing a series of basic conv3x3 layers to extract detail features.
+    """
+    def __init__(
+        self,
+        in_chans = 4,
+        out_chans = [48, 96, 192],
+    ):
+        super().__init__()
+        self.convs = nn.ModuleList()
+
+        self.conv_chans = out_chans.copy()
+        self.conv_chans.insert(0, in_chans)
+
+        for i in range(len(self.conv_chans)-1):
+            in_chan_ = self.conv_chans[i]
+            out_chan_ = self.conv_chans[i+1]
+            self.convs.append(
+                Basic_Conv3x3(in_chan_, out_chan_)
+            )
+
+    def forward(self, x):
+        out_dict = {'D0': x}
+        for i in range(len(self.convs)):
+            x = self.convs[i](x)
+            name_ = 'D'+str(i+1)
+            out_dict[name_] = x
+
+        return out_dict
+
+
+class SegmentationHead(nn.Sequential):
+    # https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/base/heads.py#L5
+    def __init__(
+        self, in_channels, out_channels, kernel_size=3, activation=None, use_attention=False,
+    ):
+        if use_attention:
+            attention = AttentionBlock(in_channels)
+        else:
+            attention = nn.Identity()
+        conv2d = nn.Conv2d(
+            in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
+        )
+        activation = activation
+        super().__init__(attention, conv2d, activation)
+
+
+class AttentionBlock(nn.Module):
+    """
+    Attention gate
+
+    Parameters:
+    -----------
+    in_chns : int
+        Number of input channels.
+
+    Forward Input:
+    --------------
+    x : torch.Tensor
+        Input tensor of shape [B, C, H, W].
+
+    Returns:
+    --------
+    torch.Tensor
+        Reweighted tensor of the same shape as input.
+    """
+    def __init__(self, in_chns):
+        super(AttentionBlock, self).__init__()
+        # Attention generation
+        self.psi = nn.Sequential(
+            nn.Conv2d(in_chns, in_chns // 2, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.BatchNorm2d(in_chns // 2),
+            nn.ReLU(),
+            nn.Conv2d(in_chns // 2, 1, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        # Project decoder output to intermediate space
+        g = self.psi(x)
+        return x * g
+
+
+class Fusion_Block(nn.Module):
+    """
+    Simple fusion block to fuse features from ConvStream and Plain Vision Transformer.
+    """
+    def __init__(
+        self,
+        in_chans,
+        out_chans,
+    ):
+        super().__init__()
+        self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1)
+
+    def forward(self, x, D):
+        F_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)  ## Nearest ?
+        out = torch.cat([D, F_up], dim=1)
+        out = self.conv(out)
+
+        return out
+
+
+class MIPHEIViT(nn.Module):
+    """
+    U-Net-style architecture inspired by ViTMatte, using a Vision Transformer (ViT or Swin)
+    as encoder and a convolutional decoder. Designed for dense image prediction tasks,
+    such as image-to-image translation.
+
+    Parameters:
+    -----------
+    encoder : nn.Module
+        A ViT- or Swin-based encoder that outputs spatial feature maps.
+    decoder : nn.Module
+        A decoder module that maps encoder features (and optionally the original image)
+        to the output prediction.
+
+    Example:
+    --------
+    model = MIPHEIViT(encoder=Encoder(vit), decoder=Detail_Capture(emb_chans=vit.embed_dim))
+    output = model(input_tensor)
+    """
+    def __init__(self,
+                 encoder,
+                 decoder,
+                 ):
+        super(MIPHEIViT, self).__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.initialize()
+
+    def forward(self, x):
+
+        features = self.encoder(x)
+        outputs = self.decoder(features, x)
+        return outputs
+
+    def initialize(self):
+        pass
+
+    @classmethod
+    def from_pretrained_hf(cls, repo_path=None, repo_id=None):
+        from safetensors.torch import load_file
+        import json
+        if repo_path:
+            weights_path = os.path.join(repo_path, "model.safetensors")
+            config_path = os.path.join(repo_path, "config_hf.json")
+        else:
+            from huggingface_hub import hf_hub_download
+            weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
+            config_path = hf_hub_download(repo_id=repo_id, filename="config_hf.json")
+
+        # Load config values
+        with open(config_path, "r") as f:
+            config = json.load(f)
+
+        img_size = config["img_size"]
+        nc_out = len(config["targ_channel_names"])
+        use_attention = config["use_attention"]
+        hoptimus_hf_id = config["hoptimus_hf_id"]
+
+        vit = get_hoptimus0_hf(hoptimus_hf_id)
+        vit.set_input_size(img_size=(img_size, img_size))
+        encoder = Encoder(vit)
+        decoder = Detail_Capture(emb_chans=encoder.embed_dim, out_chans=nc_out, use_attention=use_attention, activation=nn.Tanh())
+        model = cls(encoder=encoder, decoder=decoder)
+        state_dict = load_file(weights_path)
+        state_dict = merge_lora_weights(model, state_dict)
+        load_info = model.load_state_dict(state_dict, strict=False)
+        validate_load_info(load_info)
+        model.eval()
+        return model
+
+    def set_input_size(self, img_size):
+        if any((s & (s - 1)) != 0 or s == 0 for s in img_size):
+            raise ValueError("Both height and width in img_size must be powers of 2")
+        if any(s < 128 for s in img_size):
+            raise ValueError("Height and width must be greater than or equal to 128")
+        self.encoder.vit.set_input_size(img_size=img_size)
+        self.encoder.grid_size = self.encoder.vit.patch_embed.grid_size
+
+
+class Encoder(nn.Module):
+    """
+    Wraps a Vision Transformer (ViT or Swin) to produce feature maps compatible
+    with U-Net-like architectures. It reshapes and resizes transformer outputs
+    into spatial feature maps.
+
+    Parameters:
+    -----------
+    vit : VisionTransformer or SwinTransformer
+        A pretrained transformer model from `timm` that outputs patch embeddings.
+    """
+    def __init__(self, vit):
+        super().__init__()
+        if not isinstance(vit, (VisionTransformer, SwinTransformer)):
+            raise ValueError(f"Expected a VisionTransformer or SwinTransformer, got {type(vit)}")
+        self.vit = vit
+
+        self.is_swint = isinstance(vit, SwinTransformer)
+        self.grid_size = self.vit.patch_embed.grid_size
+        if self.is_swint:
+            self.num_prefix_tokens = 0
+            self.embed_dim = self.vit.embed_dim * 2 ** (self.vit.num_layers - 1)
+        else:
+            self.num_prefix_tokens = self.vit.num_prefix_tokens
+            self.embed_dim = self.vit.embed_dim
+        patch_size = self.vit.patch_embed.patch_size
+        img_size = self.vit.patch_embed.img_size
+        assert img_size[0] % 16 == 0
+        assert img_size[1] % 16 == 0
+
+        if self.is_swint:
+            self.scale_factor = (2., 2.)
+        else:
+            if patch_size != (16, 16):
+                target_grid_size = (img_size[0] / 16, img_size[1] / 16)
+                self.scale_factor = (target_grid_size[0] / self.grid_size[0], target_grid_size[1] / self.grid_size[1])
+            else:
+                self.scale_factor = None
+
+    def forward(self, x):
+        features = self.vit(x)
+        if self.is_swint:
+            features = features.permute(0, 3, 1, 2)
+        else:
+            features = features[:, self.num_prefix_tokens:]
+            features = features.permute(0, 2, 1)
+            features = features.view((-1, self.embed_dim, *self.grid_size))
+        if self.scale_factor is not None:
+            features = F.interpolate(features, scale_factor=self.scale_factor, mode="bicubic")
+        return features
+
+
+class Detail_Capture(nn.Module):
+    """
+    Simple and Lightweight Detail Capture Module for ViT Matting.
+    """
+    def __init__(
+        self,
+        emb_chans,
+        in_chans=3,
+        out_chans=1,
+        convstream_out = [48, 96, 192],
+        fusion_out = [256, 128, 64, 32],
+        use_attention=True,
+        activation=torch.nn.Identity()
+    ):
+        super().__init__()
+        assert len(fusion_out) == len(convstream_out) + 1
+
+        self.convstream = ConvStream(in_chans=in_chans)
+        self.conv_chans = self.convstream.conv_chans
+        self.num_heads = out_chans
+
+        self.fusion_blks = nn.ModuleList()
+        self.fus_channs = fusion_out.copy()
+        self.fus_channs.insert(0, emb_chans)
+        for i in range(len(self.fus_channs)-1):
+            self.fusion_blks.append(
+                Fusion_Block(
+                    in_chans = self.fus_channs[i] + self.conv_chans[-(i+1)],
+                    out_chans = self.fus_channs[i+1],
+                )
+            )
+
+        for idx in range(self.num_heads):
+            setattr(self, f'segmentation_head_{idx}', SegmentationHead(
+                in_channels=fusion_out[-1],
+                out_channels=1,
+                activation=activation,
+                kernel_size=3,
+                use_attention=use_attention
+            ))
+
+    def forward(self, features, images):
+        detail_features = self.convstream(images)
+        for i in range(len(self.fusion_blks)):
+            d_name_ = 'D'+str(len(self.fusion_blks)-i-1)
+            features = self.fusion_blks[i](features, detail_features[d_name_])
+
+        outputs = []
+        for idx_head in range(self.num_heads):
+            segmentation_head = getattr(self, f'segmentation_head_{idx_head}')
+            output = segmentation_head(features)
+            outputs.append(output)
+        outputs = torch.cat(outputs, dim=1)
+
+        return outputs
+
+
+def merge_lora_weights(model, state_dict, alpha=1.0, block_prefix="encoder.vit.blocks"):
+    """
+    Merges LoRA weights into the base attention Q and V projection weights for each transformer block.
+    We keep LoRA weights in the model.safetensors to avoid having the original foundation model weights in the repo.
+
+    Parameters:
+    -----------
+    model : torch.nn.Module
+        The model containing the transformer blocks to modify (e.g., ViT backbone).
+    state_dict : dict
+        The state_dict containing LoRA matrices with keys formatted as
+        '{block_prefix}.{idx}.attn.qkv.lora_q.A', etc.
+        This dict is modified in-place to remove LoRA weights after merging.
+    alpha : float, optional
+        Scaling factor for the LoRA update. Defaults to 1.0.
+    block_prefix : str, optional
+        Prefix to locate transformer blocks in the model. Defaults to "encoder.vit.blocks".
+
+    Returns:
+    --------
+    dict
+        The modified state_dict with LoRA weights removed after merging.
+    """
+    with torch.no_grad():
+        for idx in range(len(model.encoder.vit.blocks)):
+            prefix = f"{block_prefix}.{idx}.attn.qkv"
+
+            # Extract LoRA matrices
+            A_q = state_dict.pop(f"{prefix}.lora_q.A")
+            B_q = state_dict.pop(f"{prefix}.lora_q.B")
+            A_v = state_dict.pop(f"{prefix}.lora_v.A")
+            B_v = state_dict.pop(f"{prefix}.lora_v.B")
+
+            # Compute low-rank updates (transposed to match weight shape)
+            delta_q = (alpha * A_q @ B_q).T
+            delta_v = (alpha * A_v @ B_v).T
+
+            # Get original QKV weight matrix (shape: [3*dim, dim])
+            W = model.get_parameter(f"{prefix}.weight")
+            dim = delta_q.shape[0]
+            assert W.shape[0] == 3 * dim, f"Unexpected QKV shape: {W.shape}"
+
+            # Apply LoRA deltas to Q and V projections
+            W[:dim, :] += delta_q  # Q projection
+            W[2 * dim:, :] += delta_v  # V projection
+
+    return state_dict
+
+
+def get_hoptimus0_hf(repo_id):
+    """ H-optimus-0 foundation model from a Hugging Face repo id
+    """
+    model = timm.create_model(
+        "vit_giant_patch14_reg4_dinov2", img_size=224,
+        drop_path_rate=0., num_classes=0,
+        global_pool="", pretrained=False, init_values=1e-5,
+        dynamic_img_size=False)
+    state_dict = load_state_dict_from_hf(repo_id, weights_only=True)
+    model.load_state_dict(state_dict)
+    return model
+
+
+def validate_load_info(load_info):
+    """
+    Validates the result of model.load_state_dict(..., strict=False).
+
+    Raises:
+        ValueError if unexpected keys are found,
+        or if missing keys are not related to the allowed encoder modules.
+    """
+    # 1. Raise if any unexpected keys
+    if load_info.unexpected_keys:
+        raise ValueError(f"Unexpected keys in state_dict: {load_info.unexpected_keys}")
+
+    # 2. Raise if any missing keys are not part of allowed encoder modules
+    for key in load_info.missing_keys:
+        if ".lora" in key:
+            raise ValueError(f"Missing LoRA checkpoint in state_dict: {key}")
+        elif not any(part in key for part in ["encoder.vit.", "encoder.model."]):
+            raise ValueError(f"Missing key in state_dict: {key}")
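model.py can also be used on its own, without the Gradio app. A minimal sketch follows; it is not part of the commit, and it assumes the Hub config requests 256×256 inputs and 16 target channels, as app.py suggests.

# Standalone usage sketch for MIPHEIViT; the 256x256 size and the 16-channel output
# are assumptions carried over from app.py, not guarantees of this file.
import torch

from model import MIPHEIViT

model = MIPHEIViT.from_pretrained_hf(repo_id="Estabousi/MIPHEI-vit")  # merges LoRA deltas into the ViT QKV weights
model.eval()

x = torch.randn(1, 3, 256, 256)   # stands in for an ImageNet-normalized H&E tile
with torch.inference_mode():
    y = model(x)                  # expected shape: [1, 16, 256, 256], values in the tanh range
print(y.shape)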
requirements.txt
ADDED
@@ -0,0 +1,7 @@
+torch
+timm
+gradio
+safetensors
+numpy
+Pillow
+huggingface_hub