U1020040 committed · Commit 2dc4da8 · 1 Parent(s): 4c1ef27

first commit

Files changed:
- .gitignore +3 -0
- config.yaml +63 -0
- config_hf.json +11 -0
- logo.svg +232 -0
- logreg.pth +3 -0
- model.py +410 -0
- model.safetensors +3 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
+__pycache__/
+*.py[cod]
+*.so
config.yaml
ADDED
@@ -0,0 +1,63 @@
+data:
+  slide_dataframe_path: /home/U1020040/data/slide_dataframe.csv
+  train_dataframe_path: /home/U1020040/data/train_dataframe.csv
+  val_dataframe_path: /home/U1020040/data/val_dataframe.csv
+  test_dataframe_path: /home/U1020040/data/test_dataframe.csv
+  augmentation_dir: /root/workdir/tile_orion_norm_slides
+  channel_stats_path: channel_stats.json
+  targ_channel_names:
+    - Hoechst
+    - CD31
+    - CD45
+    - CD68
+    - CD4
+    - FOXP3
+    - CD8a
+    - CD45RO
+    - CD20
+    - PD-L1
+    - CD3e
+    - CD163
+    - E-cadherin
+    - Ki67
+    - Pan-CK
+    - SMA
+train:
+  epochs: 15
+  batch_size: 16
+  gan_train: false
+  gan_mode: structural
+  learning_rate_d: 0.0002
+  learning_rate_g: 0.0002
+  precision: 16-mixed
+  foreground_head: false
+  use_cell_metrics: true
+  wandb_project: he-if-image-to-image
+  wandb_note: model_vitmatte
+losses:
+  lambda_factor: 50
+  use_weighted_mae: false
+  adversarial_loss: binary_crossentropy
+  perceptual_loss: false
+  cell_loss:
+    use_loss: false
+    use_mse: false
+    use_clustering: false
+    mlp_path: mlp.ckpt
+callbacks:
+  modelcheckpoint:
+    mode: max
+    monitor: val_cell_auc
+data_sampler:
+  use_sampler: false
+  mode: cell
+  tresh: 4
+  other_percent: 0.2
+model:
+  model_name: myvitmatte
+  dropout: 0.1
+  foreground_head: false
+  checkpoint_path: null
+  encoder:
+    encoder_name: hoptimus0
+    encoder_weights: /root/workdir/foundation_models/hoptimus0.bin
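Note: a minimal sketch of consuming this training config (assumes PyYAML is installed and that the nesting shown above reflects the intended structure; the print call is only illustrative):

import yaml

# Load the training configuration committed above.
with open("config.yaml", "r") as f:
    cfg = yaml.safe_load(f)

# The 16 target immunofluorescence markers the model is trained to predict.
channels = cfg["data"]["targ_channel_names"]
print(len(channels), channels[:3])   # 16 ['Hoechst', 'CD31', 'CD45']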
config_hf.json
ADDED
@@ -0,0 +1,11 @@
+{
+  "img_size": 256,
+  "targ_channel_names": [
+    "Hoechst", "CD31", "CD45", "CD68", "CD4", "FOXP3", "CD8a",
+    "CD45RO", "CD20", "PD-L1", "CD3e", "CD163", "E-cadherin",
+    "Ki67", "Pan-CK", "SMA"
+  ],
+  "use_attention": true,
+  "hoptimus_hf_id": "bioptimus/H-optimus-0",
+  "license": "Sanofi Custom CC BY-NCC"
+}
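Note: these fields are exactly what MIPHEIViT.from_pretrained_hf in model.py (below) reads to rebuild the architecture; a sketch of that mapping, using only keys present in config_hf.json:

import json

with open("config_hf.json") as f:
    config = json.load(f)

img_size = config["img_size"]                 # 256: tile size the ViT encoder is resized to
nc_out = len(config["targ_channel_names"])    # 16: one segmentation head per marker channel
use_attention = config["use_attention"]       # enables the attention gate in each head
hoptimus_hf_id = config["hoptimus_hf_id"]     # "bioptimus/H-optimus-0" backbone to download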
logo.svg
ADDED
logreg.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba8cdc7ec1aa41f017ecb159771c1ab412b399cab0a56266a2169fad3c0c2aea
+size 2781
model.py
ADDED
@@ -0,0 +1,410 @@
+"""
+This script defines the MIPHEI-ViT architecture for image-to-image translation.
+Some modules in this file are adapted from: https://github.com/hustvl/ViTMatte/
+"""
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import timm
+from timm.models import VisionTransformer, SwinTransformer
+from timm.models import load_state_dict_from_hf
+
+
+class Basic_Conv3x3(nn.Module):
+    """
+    Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers.
+    https://github.com/hustvl/ViTMatte/blob/main/modeling/decoder/detail_capture.py#L5
+    """
+    def __init__(
+        self,
+        in_chans,
+        out_chans,
+        stride=2,
+        padding=1,
+    ):
+        super().__init__()
+        self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False)
+        self.bn = nn.BatchNorm2d(out_chans)
+        self.relu = nn.ReLU(inplace=False)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.relu(x)
+
+        return x
+
+
+class ConvStream(nn.Module):
+    """
+    Simple ConvStream containing a series of basic conv3x3 layers to extract detail features.
+    """
+    def __init__(
+        self,
+        in_chans=4,
+        out_chans=[48, 96, 192],
+    ):
+        super().__init__()
+        self.convs = nn.ModuleList()
+
+        self.conv_chans = out_chans.copy()
+        self.conv_chans.insert(0, in_chans)
+
+        for i in range(len(self.conv_chans)-1):
+            in_chan_ = self.conv_chans[i]
+            out_chan_ = self.conv_chans[i+1]
+            self.convs.append(
+                Basic_Conv3x3(in_chan_, out_chan_)
+            )
+
+    def forward(self, x):
+        out_dict = {'D0': x}
+        for i in range(len(self.convs)):
+            x = self.convs[i](x)
+            name_ = 'D'+str(i+1)
+            out_dict[name_] = x
+
+        return out_dict
+
+
+class SegmentationHead(nn.Sequential):
+    # https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/base/heads.py#L5
+    def __init__(
+        self, in_channels, out_channels, kernel_size=3, activation=None, use_attention=False,
+    ):
+        if use_attention:
+            attention = AttentionBlock(in_channels)
+        else:
+            attention = nn.Identity()
+        conv2d = nn.Conv2d(
+            in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
+        )
+        activation = activation
+        super().__init__(attention, conv2d, activation)
+
+
+class AttentionBlock(nn.Module):
+    """
+    Attention gate
+
+    Parameters:
+    -----------
+    in_chns : int
+        Number of input channels.
+
+    Forward Input:
+    --------------
+    x : torch.Tensor
+        Input tensor of shape [B, C, H, W].
+
+    Returns:
+    --------
+    torch.Tensor
+        Reweighted tensor of the same shape as input.
+    """
+    def __init__(self, in_chns):
+        super(AttentionBlock, self).__init__()
+        # Attention generation
+        self.psi = nn.Sequential(
+            nn.Conv2d(in_chns, in_chns // 2, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.BatchNorm2d(in_chns // 2),
+            nn.ReLU(),
+            nn.Conv2d(in_chns // 2, 1, kernel_size=1, stride=1, padding=0, bias=True),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        # Project decoder output to intermediate space
+        g = self.psi(x)
+        return x * g
+
+
+class Fusion_Block(nn.Module):
+    """
+    Simple fusion block to fuse features from ConvStream and Plain Vision Transformer.
+    """
+    def __init__(
+        self,
+        in_chans,
+        out_chans,
+    ):
+        super().__init__()
+        self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1)
+
+    def forward(self, x, D):
+        F_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)  ## Nearest ?
+        out = torch.cat([D, F_up], dim=1)
+        out = self.conv(out)
+
+        return out
+
+
+class MIPHEIViT(nn.Module):
+    """
+    U-Net-style architecture inspired by ViTMatte, using a Vision Transformer (ViT or Swin)
+    as encoder and a convolutional decoder. Designed for dense image prediction tasks,
+    such as image-to-image translation.
+
+    Parameters:
+    -----------
+    encoder : nn.Module
+        A ViT- or Swin-based encoder that outputs spatial feature maps.
+    decoder : nn.Module
+        A decoder module that maps encoder features (and optionally the original image)
+        to the output prediction.
+
+    Example:
+    --------
+    model = MIPHEIViT(encoder=Encoder(vit), decoder=UNetDecoder())
+    output = model(input_tensor)
+    """
+    def __init__(self,
+                 encoder,
+                 decoder,
+                 ):
+        super(MIPHEIViT, self).__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.initialize()
+
+    def forward(self, x):
+
+        features = self.encoder(x)
+        outputs = self.decoder(features, x)
+        return outputs
+
+    def initialize(self):
+        pass
+
+    @classmethod
+    def from_pretrained_hf(cls, repo_path=None, repo_id=None):
+        from safetensors.torch import load_file
+        import json
+        if repo_path:
+            weights_path = os.path.join(repo_path, "model.safetensors")
+            config_path = os.path.join(repo_path, "config_hf.json")
+        else:
+            from huggingface_hub import hf_hub_download
+            weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
+            config_path = hf_hub_download(repo_id=repo_id, filename="config_hf.json")
+
+        # Load config values
+        with open(config_path, "r") as f:
+            config = json.load(f)
+
+        img_size = config["img_size"]
+        nc_out = len(config["targ_channel_names"])
+        use_attention = config["use_attention"]
+        hoptimus_hf_id = config["hoptimus_hf_id"]
+
+        vit = get_hoptimus0_hf(hoptimus_hf_id)
+        vit.set_input_size(img_size=(img_size, img_size))
+        encoder = Encoder(vit)
+        decoder = Detail_Capture(emb_chans=encoder.embed_dim, out_chans=nc_out, use_attention=use_attention, activation=nn.Tanh())
+        model = cls(encoder=encoder, decoder=decoder)
+        state_dict = load_file(weights_path)
+        state_dict = merge_lora_weights(model, state_dict)
+        load_info = model.load_state_dict(state_dict, strict=False)
+        validate_load_info(load_info)
+        model.eval()
+        return model
+
+    def set_input_size(self, img_size):
+        if any((s & (s - 1)) != 0 or s == 0 for s in img_size):
+            raise ValueError("Both height and width in img_size must be powers of 2")
+        if any(s < 128 for s in img_size):
+            raise ValueError("Height and width must be greater or equal to 128")
+        self.encoder.vit.set_input_size(img_size=img_size)
+        self.encoder.grid_size = self.encoder.vit.patch_embed.grid_size
+
+
+class Encoder(nn.Module):
+    """
+    Wraps a Vision Transformer (ViT or Swin) to produce feature maps compatible
+    with U-Net-like architectures. It reshapes and resizes transformer outputs
+    into spatial feature maps.
+
+    Parameters:
+    -----------
+    vit : VisionTransformer or SwinTransformer
+        A pretrained transformer model from `timm` that outputs patch embeddings.
+    """
+    def __init__(self, vit):
+        super().__init__()
+        if not isinstance(vit, (VisionTransformer, SwinTransformer)):
+            raise ValueError(f"Expected a VisionTransformer or SwinTransformer, got {type(vit)}")
+        self.vit = vit
+
+        self.is_swint = isinstance(vit, SwinTransformer)
+        self.grid_size = self.vit.patch_embed.grid_size
+        if self.is_swint:
+            self.num_prefix_tokens = 0
+            self.embed_dim = self.vit.embed_dim * 2 ** (self.vit.num_layers - 1)
+        else:
+            self.num_prefix_tokens = self.vit.num_prefix_tokens
+            self.embed_dim = self.vit.embed_dim
+        patch_size = self.vit.patch_embed.patch_size
+        img_size = self.vit.patch_embed.img_size
+        assert img_size[0] % 16 == 0
+        assert img_size[1] % 16 == 0
+
+        if self.is_swint:
+            self.scale_factor = (2., 2.)
+        else:
+            if patch_size != (16, 16):
+                target_grid_size = (img_size[0] / 16, img_size[1] / 16)
+                self.scale_factor = (target_grid_size[0] / self.grid_size[0], target_grid_size[1] / self.grid_size[1])
+            else:
+                self.scale_factor = None
+
+    def forward(self, x):
+        features = self.vit(x)
+        if self.is_swint:
+            features = features.permute(0, 3, 1, 2)
+        else:
+            features = features[:, self.num_prefix_tokens:]
+            features = features.permute(0, 2, 1)
+            features = features.view((-1, self.embed_dim, *self.grid_size))
+        if self.scale_factor is not None:
+            features = F.interpolate(features, scale_factor=self.scale_factor, mode="bicubic")
+        return features
+
+
+class Detail_Capture(nn.Module):
+    """
+    Simple and Lightweight Detail Capture Module for ViT Matting.
+    """
+    def __init__(
+        self,
+        emb_chans,
+        in_chans=3,
+        out_chans=1,
+        convstream_out=[48, 96, 192],
+        fusion_out=[256, 128, 64, 32],
+        use_attention=True,
+        activation=torch.nn.Identity()
+    ):
+        super().__init__()
+        assert len(fusion_out) == len(convstream_out) + 1
+
+        self.convstream = ConvStream(in_chans=in_chans)
+        self.conv_chans = self.convstream.conv_chans
+        self.num_heads = out_chans
+
+        self.fusion_blks = nn.ModuleList()
+        self.fus_channs = fusion_out.copy()
+        self.fus_channs.insert(0, emb_chans)
+        for i in range(len(self.fus_channs)-1):
+            self.fusion_blks.append(
+                Fusion_Block(
+                    in_chans=self.fus_channs[i] + self.conv_chans[-(i+1)],
+                    out_chans=self.fus_channs[i+1],
+                )
+            )
+
+        for idx in range(self.num_heads):
+            setattr(self, f'segmentation_head_{idx}', SegmentationHead(
+                in_channels=fusion_out[-1],
+                out_channels=1,
+                activation=activation,
+                kernel_size=3,
+                use_attention=use_attention
+            ))
+
+    def forward(self, features, images):
+        detail_features = self.convstream(images)
+        for i in range(len(self.fusion_blks)):
+            d_name_ = 'D'+str(len(self.fusion_blks)-i-1)
+            features = self.fusion_blks[i](features, detail_features[d_name_])
+
+        outputs = []
+        for idx_head in range(self.num_heads):
+            segmentation_head = getattr(self, f'segmentation_head_{idx_head}')
+            output = segmentation_head(features)
+            outputs.append(output)
+        outputs = torch.cat(outputs, dim=1)
+
+        return outputs
+
+
+def merge_lora_weights(model, state_dict, alpha=1.0, block_prefix="encoder.vit.blocks"):
+    """
+    Merges LoRA weights into the base attention Q and V projection weights for each transformer block.
+    We keep LoRA weights in the model.safetensors to avoid having the original foundation model weights in the repo.
+
+    Parameters:
+    -----------
+    model : torch.nn.Module
+        The model containing the transformer blocks to modify (e.g., ViT backbone).
+    state_dict : dict
+        The state_dict containing LoRA matrices with keys formatted as
+        '{block_prefix}.{idx}.attn.qkv.lora_q.A', etc.
+        This dict is modified in-place to remove LoRA weights after merging.
+    alpha : float, optional
+        Scaling factor for the LoRA update. Defaults to 1.0.
+    block_prefix : str, optional
+        Prefix to locate transformer blocks in the model. Defaults to "encoder.vit.blocks".
+
+    Returns:
+    --------
+    dict
+        The modified state_dict with LoRA weights removed after merging.
+    """
+    with torch.no_grad():
+        for idx in range(len(model.encoder.vit.blocks)):
+            prefix = f"{block_prefix}.{idx}.attn.qkv"
+
+            # Extract LoRA matrices
+            A_q = state_dict.pop(f"{prefix}.lora_q.A")
+            B_q = state_dict.pop(f"{prefix}.lora_q.B")
+            A_v = state_dict.pop(f"{prefix}.lora_v.A")
+            B_v = state_dict.pop(f"{prefix}.lora_v.B")
+
+            # Compute low-rank updates (transposed to match weight shape)
+            delta_q = (alpha * A_q @ B_q).T
+            delta_v = (alpha * A_v @ B_v).T
+
+            # Get original QKV weight matrix (shape: [3*dim, dim])
+            W = model.get_parameter(f"{prefix}.weight")
+            dim = delta_q.shape[0]
+            assert W.shape[0] == 3 * dim, f"Unexpected QKV shape: {W.shape}"
+
+            # Apply LoRA deltas to Q and V projections
+            W[:dim, :] += delta_q  # Q projection
+            W[2 * dim:, :] += delta_v  # V projection
+
+    return state_dict
+
+
+def get_hoptimus0_hf(repo_id):
+    """ H-optimus foundation model from a Hugging Face repo id
+    """
+    model = timm.create_model(
+        "vit_giant_patch14_reg4_dinov2", img_size=224,
+        drop_path_rate=0., num_classes=0,
+        global_pool="", pretrained=False, init_values=1e-5,
+        dynamic_img_size=False)
+    state_dict = load_state_dict_from_hf(repo_id, weights_only=True)
+    model.load_state_dict(state_dict)
+    return model
+
+
+def validate_load_info(load_info):
+    """
+    Validates the result of model.load_state_dict(..., strict=False).
+
+    Raises:
+        ValueError if unexpected keys are found,
+        or if missing keys are not related to the allowed encoder modules.
+    """
+    # 1. Raise if any unexpected keys
+    if load_info.unexpected_keys:
+        raise ValueError(f"Unexpected keys in state_dict: {load_info.unexpected_keys}")
+
+    # 2. Raise if any missing keys are not part of allowed encoder modules
+    for key in load_info.missing_keys:
+        if ".lora" in key:
+            raise ValueError(f"Missing LoRA checkpoint in state_dict: {key}")
+        elif not any(part in key for part in ["encoder.vit.", "encoder.model."]):
+            raise ValueError(f"Missing key in state_dict: {key}")
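Note: a minimal inference sketch against the file above. It assumes the repo has been cloned locally (model.safetensors and config_hf.json in the working directory); the repo_id alternative is a placeholder, and torch.randn stands in for a properly normalized 256x256 H&E tile:

import torch
from model import MIPHEIViT

# Rebuilds the H-optimus-0 encoder, merges the LoRA weights, and loads the decoder.
model = MIPHEIViT.from_pretrained_hf(repo_path=".")   # or repo_id="<user>/<repo>"

x = torch.randn(1, 3, 256, 256)   # placeholder for a normalized H&E input tile
with torch.no_grad():
    y = model(x)
print(y.shape)   # torch.Size([1, 16, 256, 256]): one Tanh-activated map per marker channel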
model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7d426fb3ad3635413ca93de3cc41529a191f70e6930fc5074e66a3da0d85fe43
+size 26840896
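Note: per the merge_lora_weights docstring, model.safetensors (about 27 MB) stores only the decoder and the LoRA adapters rather than the full H-optimus-0 backbone, which is re-downloaded from bioptimus/H-optimus-0 at load time. A toy check of the merge arithmetic (shapes here are illustrative, not the real ViT-giant dimensions):

import torch

dim, rank = 8, 2
W_qkv = torch.randn(3 * dim, dim)            # fused QKV projection, rows ordered [Q; K; V]
A, B = torch.randn(dim, rank), torch.randn(rank, dim)

x = torch.randn(5, dim)                      # 5 tokens
q_lora = x @ W_qkv[:dim].T + x @ (A @ B)     # LoRA applied at runtime to the Q projection

W_merged = W_qkv.clone()
W_merged[:dim] += (A @ B).T                  # the fold-in done per block by merge_lora_weights
q_merged = x @ W_merged[:dim].T

assert torch.allclose(q_lora, q_merged, atol=1e-5)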