U1020040 committed
Commit 2dc4da8 · 1 Parent(s): 4c1ef27

first commit

Files changed (7)
  1. .gitignore +3 -0
  2. config.yaml +63 -0
  3. config_hf.json +11 -0
  4. logo.svg +232 -0
  5. logreg.pth +3 -0
  6. model.py +410 -0
  7. model.safetensors +3 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ __pycache__/
+ *.py[cod]
+ *.so
config.yaml ADDED
@@ -0,0 +1,63 @@
+ data:
+   slide_dataframe_path: /home/U1020040/data/slide_dataframe.csv
+   train_dataframe_path: /home/U1020040/data/train_dataframe.csv
+   val_dataframe_path: /home/U1020040/data/val_dataframe.csv
+   test_dataframe_path: /home/U1020040/data/test_dataframe.csv
+   augmentation_dir: /root/workdir/tile_orion_norm_slides
+   channel_stats_path: channel_stats.json
+   targ_channel_names:
+   - Hoechst
+   - CD31
+   - CD45
+   - CD68
+   - CD4
+   - FOXP3
+   - CD8a
+   - CD45RO
+   - CD20
+   - PD-L1
+   - CD3e
+   - CD163
+   - E-cadherin
+   - Ki67
+   - Pan-CK
+   - SMA
+ train:
+   epochs: 15
+   batch_size: 16
+   gan_train: false
+   gan_mode: structural
+   learning_rate_d: 0.0002
+   learning_rate_g: 0.0002
+   precision: 16-mixed
+   foreground_head: false
+   use_cell_metrics: true
+   wandb_project: he-if-image-to-image
+   wandb_note: model_vitmatte
+   losses:
+     lambda_factor: 50
+     use_weighted_mae: false
+     adversarial_loss: binary_crossentropy
+     perceptual_loss: false
+     cell_loss:
+       use_loss: false
+       use_mse: false
+       use_clustering: false
+       mlp_path: mlp.ckpt
+   callbacks:
+     modelcheckpoint:
+       mode: max
+       monitor: val_cell_auc
+   data_sampler:
+     use_sampler: false
+     mode: cell
+     tresh: 4
+     other_percent: 0.2
+ model:
+   model_name: myvitmatte
+   dropout: 0.1
+   foreground_head: false
+   checkpoint_path: null
+   encoder:
+     encoder_name: hoptimus0
+     encoder_weights: /root/workdir/foundation_models/hoptimus0.bin
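A minimal sketch of how this training configuration could be read in Python (PyYAML assumed; the filename, the variable names, and the nesting shown above are illustrative reconstructions, not guaranteed by the repo):

import yaml

# Load the training configuration (filename assumed for illustration)
with open("config.yaml", "r") as f:
    cfg = yaml.safe_load(f)

targ_channels = cfg["data"]["targ_channel_names"]  # 16 IF marker names
epochs = cfg["train"]["epochs"]                     # 15
batch_size = cfg["train"]["batch_size"]             # 16
print(f"{epochs} epochs, batch size {batch_size}, {len(targ_channels)} target channels")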
config_hf.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "img_size": 256,
+   "targ_channel_names": [
+     "Hoechst", "CD31", "CD45", "CD68", "CD4", "FOXP3", "CD8a",
+     "CD45RO", "CD20", "PD-L1", "CD3e", "CD163", "E-cadherin",
+     "Ki67", "Pan-CK", "SMA"
+   ],
+   "use_attention": true,
+   "hoptimus_hf_id": "bioptimus/H-optimus-0",
+   "license": "Sanofi Custom CC BY-NCC"
+ }
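config_hf.json holds the fields that MIPHEIViT.from_pretrained_hf in model.py below reads when rebuilding the network; a minimal sketch of that parsing step (field names come from the file above, the local filename is an assumption):

import json

with open("config_hf.json", "r") as f:
    hf_cfg = json.load(f)

img_size = hf_cfg["img_size"]                   # 256
n_channels = len(hf_cfg["targ_channel_names"])  # 16 output IF channels
use_attention = hf_cfg["use_attention"]         # True
hoptimus_hf_id = hf_cfg["hoptimus_hf_id"]       # "bioptimus/H-optimus-0"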
logo.svg ADDED
logreg.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ba8cdc7ec1aa41f017ecb159771c1ab412b399cab0a56266a2169fad3c0c2aea
+ size 2781
model.py ADDED
@@ -0,0 +1,410 @@
+ """
+ This script defines the MIPHEI-ViT architecture for image-to-image translation.
+ Some modules in this file are adapted from: https://github.com/hustvl/ViTMatte/
+ """
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import timm
+ from timm.models import VisionTransformer, SwinTransformer
+ from timm.models import load_state_dict_from_hf
+
+
+ class Basic_Conv3x3(nn.Module):
+     """
+     Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers.
+     https://github.com/hustvl/ViTMatte/blob/main/modeling/decoder/detail_capture.py#L5
+     """
+     def __init__(
+         self,
+         in_chans,
+         out_chans,
+         stride=2,
+         padding=1,
+     ):
+         super().__init__()
+         self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False)
+         self.bn = nn.BatchNorm2d(out_chans)
+         self.relu = nn.ReLU(inplace=False)
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.bn(x)
+         x = self.relu(x)
+
+         return x
+
+
+ class ConvStream(nn.Module):
+     """
+     Simple ConvStream containing a series of basic conv3x3 layers to extract detail features.
+     """
+     def __init__(
+         self,
+         in_chans=4,
+         out_chans=[48, 96, 192],
+     ):
+         super().__init__()
+         self.convs = nn.ModuleList()
+
+         self.conv_chans = out_chans.copy()
+         self.conv_chans.insert(0, in_chans)
+
+         for i in range(len(self.conv_chans) - 1):
+             in_chan_ = self.conv_chans[i]
+             out_chan_ = self.conv_chans[i + 1]
+             self.convs.append(
+                 Basic_Conv3x3(in_chan_, out_chan_)
+             )
+
+     def forward(self, x):
+         out_dict = {'D0': x}
+         for i in range(len(self.convs)):
+             x = self.convs[i](x)
+             name_ = 'D' + str(i + 1)
+             out_dict[name_] = x
+
+         return out_dict
+
+
+ class SegmentationHead(nn.Sequential):
+     # https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/base/heads.py#L5
+     def __init__(
+         self, in_channels, out_channels, kernel_size=3, activation=None, use_attention=False,
+     ):
+         if use_attention:
+             attention = AttentionBlock(in_channels)
+         else:
+             attention = nn.Identity()
+         conv2d = nn.Conv2d(
+             in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
+         )
+         # nn.Sequential cannot hold None, so fall back to an identity activation
+         activation = activation if activation is not None else nn.Identity()
+         super().__init__(attention, conv2d, activation)
+
+
+ class AttentionBlock(nn.Module):
+     """
+     Attention gate
+
+     Parameters:
+     -----------
+     in_chns : int
+         Number of input channels.
+
+     Forward Input:
+     --------------
+     x : torch.Tensor
+         Input tensor of shape [B, C, H, W].
+
+     Returns:
+     --------
+     torch.Tensor
+         Reweighted tensor of the same shape as input.
+     """
+     def __init__(self, in_chns):
+         super(AttentionBlock, self).__init__()
+         # Attention generation: single-channel spatial gate in [0, 1]
+         self.psi = nn.Sequential(
+             nn.Conv2d(in_chns, in_chns // 2, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.BatchNorm2d(in_chns // 2),
+             nn.ReLU(),
+             nn.Conv2d(in_chns // 2, 1, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x):
+         # Compute the attention map and reweight the input
+         g = self.psi(x)
+         return x * g
+
+
+ class Fusion_Block(nn.Module):
+     """
+     Simple fusion block to fuse features from the ConvStream and the plain Vision Transformer.
+     """
+     def __init__(
+         self,
+         in_chans,
+         out_chans,
+     ):
+         super().__init__()
+         self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1)
+
+     def forward(self, x, D):
+         F_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)  # Nearest?
+         out = torch.cat([D, F_up], dim=1)
+         out = self.conv(out)
+
+         return out
+
+
+ class MIPHEIViT(nn.Module):
+     """
+     U-Net-style architecture inspired by ViTMatte, using a Vision Transformer (ViT or Swin)
+     as encoder and a convolutional decoder. Designed for dense image prediction tasks,
+     such as image-to-image translation.
+
+     Parameters:
+     -----------
+     encoder : nn.Module
+         A ViT- or Swin-based encoder that outputs spatial feature maps.
+     decoder : nn.Module
+         A decoder module that maps encoder features (and optionally the original image)
+         to the output prediction.
+
+     Example:
+     --------
+     encoder = Encoder(vit)
+     decoder = Detail_Capture(emb_chans=encoder.embed_dim, out_chans=16)
+     model = MIPHEIViT(encoder=encoder, decoder=decoder)
+     output = model(input_tensor)
+     """
+     def __init__(self,
+                  encoder,
+                  decoder,
+                  ):
+         super(MIPHEIViT, self).__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.initialize()
+
+     def forward(self, x):
+         features = self.encoder(x)
+         outputs = self.decoder(features, x)
+         return outputs
+
+     def initialize(self):
+         pass
+
+     @classmethod
+     def from_pretrained_hf(cls, repo_path=None, repo_id=None):
+         from safetensors.torch import load_file
+         import json
+         if repo_path:
+             weights_path = os.path.join(repo_path, "model.safetensors")
+             config_path = os.path.join(repo_path, "config_hf.json")
+         else:
+             from huggingface_hub import hf_hub_download
+             weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
+             config_path = hf_hub_download(repo_id=repo_id, filename="config_hf.json")
+
+         # Load config values
+         with open(config_path, "r") as f:
+             config = json.load(f)
+
+         img_size = config["img_size"]
+         nc_out = len(config["targ_channel_names"])
+         use_attention = config["use_attention"]
+         hoptimus_hf_id = config["hoptimus_hf_id"]
+
+         vit = get_hoptimus0_hf(hoptimus_hf_id)
+         vit.set_input_size(img_size=(img_size, img_size))
+         encoder = Encoder(vit)
+         decoder = Detail_Capture(emb_chans=encoder.embed_dim, out_chans=nc_out, use_attention=use_attention, activation=nn.Tanh())
+         model = cls(encoder=encoder, decoder=decoder)
+         state_dict = load_file(weights_path)
+         state_dict = merge_lora_weights(model, state_dict)
+         load_info = model.load_state_dict(state_dict, strict=False)
+         validate_load_info(load_info)
+         model.eval()
+         return model
+
+     def set_input_size(self, img_size):
+         if any((s & (s - 1)) != 0 or s == 0 for s in img_size):
+             raise ValueError("Both height and width in img_size must be powers of 2")
+         if any(s < 128 for s in img_size):
+             raise ValueError("Height and width must be greater or equal to 128")
+         self.encoder.vit.set_input_size(img_size=img_size)
+         self.encoder.grid_size = self.encoder.vit.patch_embed.grid_size
+
+
+ class Encoder(nn.Module):
+     """
+     Wraps a Vision Transformer (ViT or Swin) to produce feature maps compatible
+     with U-Net-like architectures. It reshapes and resizes transformer outputs
+     into spatial feature maps.
+
+     Parameters:
+     -----------
+     vit : VisionTransformer or SwinTransformer
+         A pretrained transformer model from `timm` that outputs patch embeddings.
+     """
+     def __init__(self, vit):
+         super().__init__()
+         if not isinstance(vit, (VisionTransformer, SwinTransformer)):
+             raise ValueError(f"Expected a VisionTransformer or SwinTransformer, got {type(vit)}")
+         self.vit = vit
+
+         self.is_swint = isinstance(vit, SwinTransformer)
+         self.grid_size = self.vit.patch_embed.grid_size
+         if self.is_swint:
+             self.num_prefix_tokens = 0
+             self.embed_dim = self.vit.embed_dim * 2 ** (self.vit.num_layers - 1)
+         else:
+             self.num_prefix_tokens = self.vit.num_prefix_tokens
+             self.embed_dim = self.vit.embed_dim
+         patch_size = self.vit.patch_embed.patch_size
+         img_size = self.vit.patch_embed.img_size
+         assert img_size[0] % 16 == 0
+         assert img_size[1] % 16 == 0
+
+         if self.is_swint:
+             self.scale_factor = (2., 2.)
+         else:
+             if patch_size != (16, 16):
+                 # Resize patch features so the decoder always sees an img_size / 16 grid
+                 target_grid_size = (img_size[0] / 16, img_size[1] / 16)
+                 self.scale_factor = (target_grid_size[0] / self.grid_size[0], target_grid_size[1] / self.grid_size[1])
+             else:
+                 self.scale_factor = None
+
+     def forward(self, x):
+         features = self.vit(x)
+         if self.is_swint:
+             features = features.permute(0, 3, 1, 2)
+         else:
+             features = features[:, self.num_prefix_tokens:]
+             features = features.permute(0, 2, 1)
+             features = features.view((-1, self.embed_dim, *self.grid_size))
+         if self.scale_factor is not None:
+             features = F.interpolate(features, scale_factor=self.scale_factor, mode="bicubic")
+         return features
+
+
+ class Detail_Capture(nn.Module):
+     """
+     Simple and lightweight Detail Capture module for ViT matting.
+     """
+     def __init__(
+         self,
+         emb_chans,
+         in_chans=3,
+         out_chans=1,
+         convstream_out=[48, 96, 192],
+         fusion_out=[256, 128, 64, 32],
+         use_attention=True,
+         activation=torch.nn.Identity()
+     ):
+         super().__init__()
+         assert len(fusion_out) == len(convstream_out) + 1
+
+         self.convstream = ConvStream(in_chans=in_chans)
+         self.conv_chans = self.convstream.conv_chans
+         self.num_heads = out_chans
+
+         self.fusion_blks = nn.ModuleList()
+         self.fus_channs = fusion_out.copy()
+         self.fus_channs.insert(0, emb_chans)
+         for i in range(len(self.fus_channs) - 1):
+             self.fusion_blks.append(
+                 Fusion_Block(
+                     in_chans=self.fus_channs[i] + self.conv_chans[-(i + 1)],
+                     out_chans=self.fus_channs[i + 1],
+                 )
+             )
+
+         # One single-channel segmentation head per output marker
+         for idx in range(self.num_heads):
+             setattr(self, f'segmentation_head_{idx}', SegmentationHead(
+                 in_channels=fusion_out[-1],
+                 out_channels=1,
+                 activation=activation,
+                 kernel_size=3,
+                 use_attention=use_attention
+             ))
+
+     def forward(self, features, images):
+         detail_features = self.convstream(images)
+         for i in range(len(self.fusion_blks)):
+             d_name_ = 'D' + str(len(self.fusion_blks) - i - 1)
+             features = self.fusion_blks[i](features, detail_features[d_name_])
+
+         outputs = []
+         for idx_head in range(self.num_heads):
+             segmentation_head = getattr(self, f'segmentation_head_{idx_head}')
+             output = segmentation_head(features)
+             outputs.append(output)
+         outputs = torch.cat(outputs, dim=1)
+
+         return outputs
+
+
+ def merge_lora_weights(model, state_dict, alpha=1.0, block_prefix="encoder.vit.blocks"):
+     """
+     Merges LoRA weights into the base attention Q and V projection weights for each transformer block.
+     We keep LoRA weights in model.safetensors to avoid shipping the original foundation model weights in the repo.
+
+     Parameters:
+     -----------
+     model : torch.nn.Module
+         The model containing the transformer blocks to modify (e.g., ViT backbone).
+     state_dict : dict
+         The state_dict containing LoRA matrices with keys formatted as
+         '{block_prefix}.{idx}.attn.qkv.lora_q.A', etc.
+         This dict is modified in-place to remove LoRA weights after merging.
+     alpha : float, optional
+         Scaling factor for the LoRA update. Defaults to 1.0.
+     block_prefix : str, optional
+         Prefix to locate transformer blocks in the model. Defaults to "encoder.vit.blocks".
+
+     Returns:
+     --------
+     dict
+         The modified state_dict with LoRA weights removed after merging.
+     """
+     with torch.no_grad():
+         for idx in range(len(model.encoder.vit.blocks)):
+             prefix = f"{block_prefix}.{idx}.attn.qkv"
+
+             # Extract LoRA matrices
+             A_q = state_dict.pop(f"{prefix}.lora_q.A")
+             B_q = state_dict.pop(f"{prefix}.lora_q.B")
+             A_v = state_dict.pop(f"{prefix}.lora_v.A")
+             B_v = state_dict.pop(f"{prefix}.lora_v.B")
+
+             # Compute low-rank updates (transposed to match weight shape)
+             delta_q = (alpha * A_q @ B_q).T
+             delta_v = (alpha * A_v @ B_v).T
+
+             # Get original QKV weight matrix (shape: [3*dim, dim])
+             W = model.get_parameter(f"{prefix}.weight")
+             dim = delta_q.shape[0]
+             assert W.shape[0] == 3 * dim, f"Unexpected QKV shape: {W.shape}"
+
+             # Apply LoRA deltas to Q and V projections
+             W[:dim, :] += delta_q  # Q projection
+             W[2 * dim:, :] += delta_v  # V projection
+
+     return state_dict
+
+
+ def get_hoptimus0_hf(repo_id):
+     """H-optimus-0 foundation model from a Hugging Face repo id."""
+     model = timm.create_model(
+         "vit_giant_patch14_reg4_dinov2", img_size=224,
+         drop_path_rate=0., num_classes=0,
+         global_pool="", pretrained=False, init_values=1e-5,
+         dynamic_img_size=False)
+     state_dict = load_state_dict_from_hf(repo_id, weights_only=True)
+     model.load_state_dict(state_dict)
+     return model
+
+
+ def validate_load_info(load_info):
+     """
+     Validates the result of model.load_state_dict(..., strict=False).
+
+     Raises:
+         ValueError if unexpected keys are found,
+         or if missing keys are not related to the allowed encoder modules.
+     """
+     # 1. Raise if any unexpected keys
+     if load_info.unexpected_keys:
+         raise ValueError(f"Unexpected keys in state_dict: {load_info.unexpected_keys}")
+
+     # 2. Raise if any missing keys are not part of allowed encoder modules
+     for key in load_info.missing_keys:
+         if ".lora" in key:
+             raise ValueError(f"Missing LoRA checkpoint in state_dict: {key}")
+         elif not any(part in key for part in ["encoder.vit.", "encoder.model."]):
+             raise ValueError(f"Missing key in state_dict: {key}")
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7d426fb3ad3635413ca93de3cc41529a191f70e6930fc5074e66a3da0d85fe43
+ size 26840896