U1020040 committed
Commit 3047e70 · 1 Parent(s): fbbdc88

first commit

Files changed (6)
  1. .gitignore +3 -0
  2. LICENSE +9 -0
  3. app.py +59 -0
  4. logo.svg +232 -0
  5. model.py +410 -0
  6. requirements.txt +7 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ __pycache__/
+ *.py[cod]
+ *.so
LICENSE ADDED
@@ -0,0 +1,9 @@
+ Non Commercial License Notice:
+
+ Copyright (c) 2025 Sanofi.
+
+ Permission is hereby granted, free of charge, for academic research purposes only and for non-commercial uses only, to any person from academic research or non-profit organizations obtaining a copy of this software and associated documentation files (the "Software"), to use, copy, modify, or merge the Software, subject to the following conditions: this permission notice shall be included in all copies of the Software or of substantial portions of the Software.
+
+ For purposes of this license, “non-commercial use” excludes uses foreseeably resulting in a commercial benefit. To use this software for other purposes (such as the development of a commercial product, including but not limited to software, service, or pharmaceuticals, or in a collaboration with a private company), please contact SANOFI at [email protected].
+
+ All other rights are reserved, including those for text and data mining, AI training and similar technologies. The Software is provided “as is”, without warranty of any kind, express or implied, including the warranties of noninfringement.
app.py ADDED
@@ -0,0 +1,59 @@
+ import gradio as gr
+ from huggingface_hub import hf_hub_download
+ import torch
+ import json
+ from PIL import Image
+ import numpy as np
+ from model import MIPHEIViT
+
+ # Load model once
+ repo_id = "Estabousi/MIPHEI-vit"
+ model = MIPHEIViT.from_pretrained_hf(repo_id=repo_id)
+ config_path = hf_hub_download(repo_id=repo_id, filename="config_hf.json")
+
+ model.eval()
+ mean = torch.Tensor([0.485, 0.456, 0.406]).to(torch.float32).reshape((-1, 1, 1))
+ std = torch.Tensor([0.229, 0.224, 0.225]).to(torch.float32).reshape((-1, 1, 1))
+ with open(config_path, "r") as f:
+     config = json.load(f)
+ channel_names = config["targ_channel_names"]
+
+ def preprocess(image):
+     image = image.convert("RGB").resize((256, 256))
+     tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255
+     tensor = (tensor - mean) / std
+     return tensor.unsqueeze(0)  # [1, 3, H, W]
+
+ def predict(image):
+     input_tensor = preprocess(image)
+     with torch.inference_mode():
+         output = model(input_tensor)[0]  # [16, H, W]
+         output = (output.clamp(-0.9, 0.9) + 0.9) / 1.8
+         output = np.uint8(output.cpu().numpy() * 255)
+
+     # Convert each mIF channel to a grayscale PIL image
+     channel_imgs = []
+     for i in range(output.shape[0]):
+         ch_img = output[i]
+         pil_ch = Image.fromarray(ch_img, mode='L')
+         channel_imgs.append(pil_ch)
+
+     # Return the list of 16 channel images
+     return channel_imgs
+
+ # Prepare Gradio UI
+ demo = gr.Interface(
+     fn=predict,
+     inputs=gr.Image(type="pil", label="Input H&E"),
+     outputs=[gr.Image(type="pil", label=f"mIF Channel {channel_names[i]}") for i in range(16)],
+     title="MIPHEI-ViT: Full mIF Prediction",
+     description=(
+         "Upload an H&E image tile (colorectal, 256×256 px at 0.5 µm/px recommended). "
+         "The image will be resized to 256×256 if needed.\n"
+         "The model predicts 16-channel multiplex immunofluorescence, "
+         "with each marker shown as a grayscale image."
+     )
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
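For reference, here is a minimal sketch of running the model outside the Gradio UI, reusing only the API exercised by app.py above (MIPHEIViT.from_pretrained_hf, the ImageNet normalization constants, and the 16-channel Tanh output). The tile path and output filenames are illustrative placeholders, not files in this repo.

import torch
import numpy as np
from PIL import Image
from model import MIPHEIViT

model = MIPHEIViT.from_pretrained_hf(repo_id="Estabousi/MIPHEI-vit")
mean = torch.tensor([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).reshape(-1, 1, 1)

tile = Image.open("tile.png").convert("RGB").resize((256, 256))  # example input path
x = torch.from_numpy(np.array(tile)).permute(2, 0, 1).float() / 255
x = ((x - mean) / std).unsqueeze(0)  # [1, 3, 256, 256]

with torch.inference_mode():
    y = model(x)[0]  # [16, 256, 256], Tanh outputs roughly in [-1, 1]

# Rescale to 8-bit and save each predicted marker channel as a grayscale image
y = ((y.clamp(-0.9, 0.9) + 0.9) / 1.8 * 255).to(torch.uint8).cpu().numpy()
for i, ch in enumerate(y):
    Image.fromarray(ch, mode="L").save(f"channel_{i}.png")  # example output names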
logo.svg ADDED
model.py ADDED
@@ -0,0 +1,410 @@
+ """
+ This script defines the MIPHEI-ViT architecture for image-to-image translation.
+ Some modules in this file are adapted from: https://github.com/hustvl/ViTMatte/
+ """
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import timm
+ from timm.models import VisionTransformer, SwinTransformer
+ from timm.models import load_state_dict_from_hf
+
+
+ class Basic_Conv3x3(nn.Module):
+     """
+     Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers.
+     https://github.com/hustvl/ViTMatte/blob/main/modeling/decoder/detail_capture.py#L5
+     """
+     def __init__(
+         self,
+         in_chans,
+         out_chans,
+         stride=2,
+         padding=1,
+     ):
+         super().__init__()
+         self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False)
+         self.bn = nn.BatchNorm2d(out_chans)
+         self.relu = nn.ReLU(inplace=False)
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.bn(x)
+         x = self.relu(x)
+
+         return x
+
+
+ class ConvStream(nn.Module):
+     """
+     Simple ConvStream containing a series of basic conv3x3 layers to extract detail features.
+     """
+     def __init__(
+         self,
+         in_chans=4,
+         out_chans=[48, 96, 192],
+     ):
+         super().__init__()
+         self.convs = nn.ModuleList()
+
+         self.conv_chans = out_chans.copy()
+         self.conv_chans.insert(0, in_chans)
+
+         for i in range(len(self.conv_chans) - 1):
+             in_chan_ = self.conv_chans[i]
+             out_chan_ = self.conv_chans[i + 1]
+             self.convs.append(
+                 Basic_Conv3x3(in_chan_, out_chan_)
+             )
+
+     def forward(self, x):
+         out_dict = {'D0': x}
+         for i in range(len(self.convs)):
+             x = self.convs[i](x)
+             name_ = 'D' + str(i + 1)
+             out_dict[name_] = x
+
+         return out_dict
+
+
+ class SegmentationHead(nn.Sequential):
+     # https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/base/heads.py#L5
+     def __init__(
+         self, in_channels, out_channels, kernel_size=3, activation=None, use_attention=False,
+     ):
+         if use_attention:
+             attention = AttentionBlock(in_channels)
+         else:
+             attention = nn.Identity()
+         conv2d = nn.Conv2d(
+             in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
+         )
+         activation = activation
+         super().__init__(attention, conv2d, activation)
+
+
+ class AttentionBlock(nn.Module):
+     """
+     Attention gate
+
+     Parameters:
+     -----------
+     in_chns : int
+         Number of input channels.
+
+     Forward Input:
+     --------------
+     x : torch.Tensor
+         Input tensor of shape [B, C, H, W].
+
+     Returns:
+     --------
+     torch.Tensor
+         Reweighted tensor of the same shape as input.
+     """
+     def __init__(self, in_chns):
+         super(AttentionBlock, self).__init__()
+         # Attention generation
+         self.psi = nn.Sequential(
+             nn.Conv2d(in_chns, in_chns // 2, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.BatchNorm2d(in_chns // 2),
+             nn.ReLU(),
+             nn.Conv2d(in_chns // 2, 1, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.Sigmoid()
+         )
+
+     def forward(self, x):
+         # Project decoder output to an attention map and reweight the input
+         g = self.psi(x)
+         return x * g
+
+
+ class Fusion_Block(nn.Module):
+     """
+     Simple fusion block to fuse features from ConvStream and the plain Vision Transformer.
+     """
+     def __init__(
+         self,
+         in_chans,
+         out_chans,
+     ):
+         super().__init__()
+         self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1)
+
+     def forward(self, x, D):
+         F_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)  ## Nearest ?
+         out = torch.cat([D, F_up], dim=1)
+         out = self.conv(out)
+
+         return out
+
+
+ class MIPHEIViT(nn.Module):
+     """
+     U-Net-style architecture inspired by ViTMatte, using a Vision Transformer (ViT or Swin)
+     as encoder and a convolutional decoder. Designed for dense image prediction tasks,
+     such as image-to-image translation.
+
+     Parameters:
+     -----------
+     encoder : nn.Module
+         A ViT- or Swin-based encoder that outputs spatial feature maps.
+     decoder : nn.Module
+         A decoder module that maps encoder features (and optionally the original image)
+         to the output prediction.
+
+     Example:
+     --------
+     model = MIPHEIViT(encoder=Encoder(vit), decoder=UNetDecoder())
+     output = model(input_tensor)
+     """
+     def __init__(self,
+                  encoder,
+                  decoder,
+                  ):
+         super(MIPHEIViT, self).__init__()
+         self.encoder = encoder
+         self.decoder = decoder
+         self.initialize()
+
+     def forward(self, x):
+         features = self.encoder(x)
+         outputs = self.decoder(features, x)
+         return outputs
+
+     def initialize(self):
+         pass
+
+     @classmethod
+     def from_pretrained_hf(cls, repo_path=None, repo_id=None):
+         from safetensors.torch import load_file
+         import json
+         if repo_path:
+             weights_path = os.path.join(repo_path, "model.safetensors")
+             config_path = os.path.join(repo_path, "config_hf.json")
+         else:
+             from huggingface_hub import hf_hub_download
+             weights_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
+             config_path = hf_hub_download(repo_id=repo_id, filename="config_hf.json")
+
+         # Load config values
+         with open(config_path, "r") as f:
+             config = json.load(f)
+
+         img_size = config["img_size"]
+         nc_out = len(config["targ_channel_names"])
+         use_attention = config["use_attention"]
+         hoptimus_hf_id = config["hoptimus_hf_id"]
+
+         vit = get_hoptimus0_hf(hoptimus_hf_id)
+         vit.set_input_size(img_size=(img_size, img_size))
+         encoder = Encoder(vit)
+         decoder = Detail_Capture(emb_chans=encoder.embed_dim, out_chans=nc_out, use_attention=use_attention, activation=nn.Tanh())
+         model = cls(encoder=encoder, decoder=decoder)
+         state_dict = load_file(weights_path)
+         state_dict = merge_lora_weights(model, state_dict)
+         load_info = model.load_state_dict(state_dict, strict=False)
+         validate_load_info(load_info)
+         model.eval()
+         return model
+
+     def set_input_size(self, img_size):
+         if any((s & (s - 1)) != 0 or s == 0 for s in img_size):
+             raise ValueError("Both height and width in img_size must be powers of 2")
+         if any(s < 128 for s in img_size):
+             raise ValueError("Height and width must be greater than or equal to 128")
+         self.encoder.vit.set_input_size(img_size=img_size)
+         self.encoder.grid_size = self.encoder.vit.patch_embed.grid_size
+
+
+ class Encoder(nn.Module):
+     """
+     Wraps a Vision Transformer (ViT or Swin) to produce feature maps compatible
+     with U-Net-like architectures. It reshapes and resizes transformer outputs
+     into spatial feature maps.
+
+     Parameters:
+     -----------
+     vit : VisionTransformer or SwinTransformer
+         A pretrained transformer model from `timm` that outputs patch embeddings.
+     """
+     def __init__(self, vit):
+         super().__init__()
+         if not isinstance(vit, (VisionTransformer, SwinTransformer)):
+             raise ValueError(f"Expected a VisionTransformer or SwinTransformer, got {type(vit)}")
+         self.vit = vit
+
+         self.is_swint = isinstance(vit, SwinTransformer)
+         self.grid_size = self.vit.patch_embed.grid_size
+         if self.is_swint:
+             self.num_prefix_tokens = 0
+             self.embed_dim = self.vit.embed_dim * 2 ** (self.vit.num_layers - 1)
+         else:
+             self.num_prefix_tokens = self.vit.num_prefix_tokens
+             self.embed_dim = self.vit.embed_dim
+         patch_size = self.vit.patch_embed.patch_size
+         img_size = self.vit.patch_embed.img_size
+         assert img_size[0] % 16 == 0
+         assert img_size[1] % 16 == 0
+
+         if self.is_swint:
+             self.scale_factor = (2., 2.)
+         else:
+             if patch_size != (16, 16):
+                 target_grid_size = (img_size[0] / 16, img_size[1] / 16)
+                 self.scale_factor = (target_grid_size[0] / self.grid_size[0], target_grid_size[1] / self.grid_size[1])
+             else:
+                 self.scale_factor = None
+
+     def forward(self, x):
+         features = self.vit(x)
+         if self.is_swint:
+             features = features.permute(0, 3, 1, 2)
+         else:
+             features = features[:, self.num_prefix_tokens:]
+             features = features.permute(0, 2, 1)
+             features = features.view((-1, self.embed_dim, *self.grid_size))
+         if self.scale_factor is not None:
+             features = F.interpolate(features, scale_factor=self.scale_factor, mode="bicubic")
+         return features
+
+
+ class Detail_Capture(nn.Module):
+     """
+     Simple and Lightweight Detail Capture Module for ViT Matting.
+     """
+     def __init__(
+         self,
+         emb_chans,
+         in_chans=3,
+         out_chans=1,
+         convstream_out=[48, 96, 192],
+         fusion_out=[256, 128, 64, 32],
+         use_attention=True,
+         activation=torch.nn.Identity()
+     ):
+         super().__init__()
+         assert len(fusion_out) == len(convstream_out) + 1
+
+         self.convstream = ConvStream(in_chans=in_chans)
+         self.conv_chans = self.convstream.conv_chans
+         self.num_heads = out_chans
+
+         self.fusion_blks = nn.ModuleList()
+         self.fus_channs = fusion_out.copy()
+         self.fus_channs.insert(0, emb_chans)
+         for i in range(len(self.fus_channs) - 1):
+             self.fusion_blks.append(
+                 Fusion_Block(
+                     in_chans=self.fus_channs[i] + self.conv_chans[-(i + 1)],
+                     out_chans=self.fus_channs[i + 1],
+                 )
+             )
+
+         for idx in range(self.num_heads):
+             setattr(self, f'segmentation_head_{idx}', SegmentationHead(
+                 in_channels=fusion_out[-1],
+                 out_channels=1,
+                 activation=activation,
+                 kernel_size=3,
+                 use_attention=use_attention
+             ))
+
+     def forward(self, features, images):
+         detail_features = self.convstream(images)
+         for i in range(len(self.fusion_blks)):
+             d_name_ = 'D' + str(len(self.fusion_blks) - i - 1)
+             features = self.fusion_blks[i](features, detail_features[d_name_])
+
+         outputs = []
+         for idx_head in range(self.num_heads):
+             segmentation_head = getattr(self, f'segmentation_head_{idx_head}')
+             output = segmentation_head(features)
+             outputs.append(output)
+         outputs = torch.cat(outputs, dim=1)
+
+         return outputs
+
+
+ def merge_lora_weights(model, state_dict, alpha=1.0, block_prefix="encoder.vit.blocks"):
+     """
+     Merges LoRA weights into the base attention Q and V projection weights for each transformer block.
+     We keep LoRA weights in model.safetensors to avoid having the original foundation model weights in the repo.
+
+     Parameters:
+     -----------
+     model : torch.nn.Module
+         The model containing the transformer blocks to modify (e.g., ViT backbone).
+     state_dict : dict
+         The state_dict containing LoRA matrices with keys formatted as
+         '{block_prefix}.{idx}.attn.qkv.lora_q.A', etc.
+         This dict is modified in-place to remove LoRA weights after merging.
+     alpha : float, optional
+         Scaling factor for the LoRA update. Defaults to 1.0.
+     block_prefix : str, optional
+         Prefix to locate transformer blocks in the model. Defaults to "encoder.vit.blocks".
+
+     Returns:
+     --------
+     dict
+         The modified state_dict with LoRA weights removed after merging.
+     """
+     with torch.no_grad():
+         for idx in range(len(model.encoder.vit.blocks)):
+             prefix = f"{block_prefix}.{idx}.attn.qkv"
+
+             # Extract LoRA matrices
+             A_q = state_dict.pop(f"{prefix}.lora_q.A")
+             B_q = state_dict.pop(f"{prefix}.lora_q.B")
+             A_v = state_dict.pop(f"{prefix}.lora_v.A")
+             B_v = state_dict.pop(f"{prefix}.lora_v.B")
+
+             # Compute low-rank updates (transposed to match weight shape)
+             delta_q = (alpha * A_q @ B_q).T
+             delta_v = (alpha * A_v @ B_v).T
+
+             # Get original QKV weight matrix (shape: [3*dim, dim])
+             W = model.get_parameter(f"{prefix}.weight")
+             dim = delta_q.shape[0]
+             assert W.shape[0] == 3 * dim, f"Unexpected QKV shape: {W.shape}"
+
+             # Apply LoRA deltas to Q and V projections
+             W[:dim, :] += delta_q  # Q projection
+             W[2 * dim:, :] += delta_v  # V projection
+
+     return state_dict
+
+
+ def get_hoptimus0_hf(repo_id):
+     """H-optimus-0 foundation model from a Hugging Face repo id.
+     """
+     model = timm.create_model(
+         "vit_giant_patch14_reg4_dinov2", img_size=224,
+         drop_path_rate=0., num_classes=0,
+         global_pool="", pretrained=False, init_values=1e-5,
+         dynamic_img_size=False)
+     state_dict = load_state_dict_from_hf(repo_id, weights_only=True)
+     model.load_state_dict(state_dict)
+     return model
+
+
+ def validate_load_info(load_info):
+     """
+     Validates the result of model.load_state_dict(..., strict=False).
+
+     Raises:
+         ValueError if unexpected keys are found,
+         or if missing keys are not related to the allowed encoder modules.
+     """
+     # 1. Raise if any unexpected keys
+     if load_info.unexpected_keys:
+         raise ValueError(f"Unexpected keys in state_dict: {load_info.unexpected_keys}")
+
+     # 2. Raise if any missing keys are not part of allowed encoder modules
+     for key in load_info.missing_keys:
+         if ".lora" in key:
+             raise ValueError(f"Missing LoRA checkpoint in state_dict: {key}")
+         elif not any(part in key for part in ["encoder.vit.", "encoder.model."]):
+             raise ValueError(f"Missing key in state_dict: {key}")
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ timm
+ gradio
+ safetensors
+ numpy
+ Pillow
+ huggingface_hub
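If the checkpoint has been cloned locally, from_pretrained_hf(repo_path=...) and set_input_size (both defined in model.py above) can be combined to run at other tile sizes. A sketch, assuming a local directory ./MIPHEI-vit containing model.safetensors and config_hf.json (the directory name is an example); set_input_size only accepts power-of-two sizes of at least 128.

import torch
from model import MIPHEIViT

model = MIPHEIViT.from_pretrained_hf(repo_path="./MIPHEI-vit")  # local clone (example path)
model.set_input_size((512, 512))  # adjust the ViT patch grid for 512x512 tiles

with torch.inference_mode():
    y = model(torch.randn(1, 3, 512, 512))
print(y.shape)  # expected: [1, 16, 512, 512] for the 16 target channels in config_hf.json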