U1020040 committed
Commit 0fb9984 · 1 Parent(s): 4cffc64

example, color, contrast

HEADER.md CHANGED
@@ -1,6 +1,9 @@
 # MIPHEI-ViT Demo: 16-channel mIF Prediction
 
 <p align="center">
+  <a title="arXiv" href="https://arxiv.org/abs/2505.10294" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+    <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
+  </a>
   <a href="https://huggingface.co/Estabousi/MIPHEI-vit" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
     <img src="https://img.shields.io/badge/🤗 Model-MIPHEI--ViT-lightgrey?logo=huggingface" height="25">
   </a>
@@ -17,7 +20,7 @@ The model returns **16 grayscale images**, each representing a predicted mIF marker
 
 ---
 
-Try it with low-zoom screenshots from public datasets:
+Try it with **provided examples** or low-zoom screenshots from public datasets:
 
 **ORION (in-domain test set):**
 - [CRC2](https://labsyspharm.github.io/orion-crc/minerva/P37_S30-CRC02/index.html#s=0&w=0&g=5&m=-1&a=-100_-100&v=1.0673_0.6057_0.5&o=-100_-100_1_1&p=Q)
app.py CHANGED
@@ -18,25 +18,63 @@ with open(config_path, "r") as f:
     config = json.load(f)
 channel_names = config["targ_channel_names"]
 
+
+channel_colors = {
+    "Hoechst": (0, 0, 255),       # Blue (DAPI, nuclear stain)
+    "CD31": (0, 255, 255),        # Cyan (endothelial)
+    "CD45": (255, 255, 0),        # Yellow (leukocyte common antigen)
+    "CD68": (255, 165, 0),        # Orange (macrophages)
+    "CD4": (255, 0, 0),           # Red (helper T cells)
+    "FOXP3": (138, 43, 226),      # Blue-Violet (regulatory T cells)
+    "CD8a": (0, 255, 0),          # Green (cytotoxic T cells)
+    "CD45RO": (255, 105, 180),    # Hot Pink (memory T cells)
+    "CD20": (0, 191, 255),        # Deep Sky Blue (B cells)
+    "PD-L1": (255, 0, 255),       # Magenta
+    "CD3e": (95, 95, 94),         # Gray (T cells)
+    "CD163": (184, 134, 11),      # Dark Goldenrod (M2 macrophages)
+    "E-cadherin": (242, 12, 43),  # Crimson (epithelial marker)
+    "Ki67": (255, 20, 147),       # Deep Pink (proliferation marker)
+    "Pan-CK": (255, 0, 0),        # Red (epithelial/carcinoma)
+    "SMA": (0, 255, 0),           # Green (smooth muscle, myofibroblasts)
+}
+
+# Contrast correction factors per channel (255 for Hoechst, 100 for the dim markers in correction_map, 150 otherwise)
+default_contrast = 150.0
+correction_map = {"Hoechst": 255.0, "CD8a": 100, "CD31": 100, "CD4": 100, "CD68": 100, "FOXP3": 100}
+max_contrast_correction_value = torch.tensor([
+    correction_map.get(name, default_contrast) / 255 for name in channel_names
+]).reshape(len(channel_names), 1, 1)
+
+
 def preprocess(image):
     image = image.convert("RGB").resize((256, 256))
     tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255
     tensor = (tensor - mean) / std
     return tensor.unsqueeze(0)  # [1, 3, H, W]
 
+def apply_color_map(gray_img, rgb_color):
+    """Map a grayscale image to RGB using a fixed pseudocolor."""
+    gray = np.asarray(gray_img).astype(np.float32) / 255.0
+    rgb = np.stack([gray * rgb_color[i] for i in range(3)], axis=-1).astype(np.uint8)
+    return Image.fromarray(rgb, mode='RGB')
+
 def predict(image):
     input_tensor = preprocess(image)
     with torch.inference_mode():
         output = model(input_tensor)[0]  # [16, H, W]
     output = (output.clamp(-0.9, 0.9) + 0.9) / 1.8
-    output = np.uint8(output.cpu().numpy() * 255)
+    output_vis = output / max_contrast_correction_value.to(output.device).clamp(min=1e-6)
+    output_vis = output_vis.clamp(0, 1) * 255
+    output_vis = np.uint8(output_vis.cpu().numpy())
+    output = output.cpu().numpy()
 
-    # Convert each mIF channel to grayscale PIL image
+    # Convert each mIF channel to a pseudocolored PIL image
     channel_imgs = []
-    for i in range(output.shape[0]):
-        ch_img = output[i]
-        pil_ch = Image.fromarray(ch_img, mode='L')
-        channel_imgs.append(pil_ch)
+    for i in range(output_vis.shape[0]):
+        ch_name = channel_names[i]
+        ch_gray = Image.fromarray(output_vis[i], mode='L')
+        ch_colored = apply_color_map(ch_gray, channel_colors[ch_name])
+        channel_imgs.append(ch_colored)
 
     # Return predicted 16 channels
     return channel_imgs
@@ -47,14 +85,37 @@ with open("HEADER.md", "r", encoding="utf-8") as f:
 
 # Build interface using Blocks
 with gr.Blocks() as demo:
+
     gr.Markdown(HEADER_MD)
-    gr.Interface(
-        fn=predict,
-        inputs=gr.Image(type="pil", label="Input H&E"),
-        outputs=[gr.Image(type="pil", label=f"mIF Channel {channel_names[i]}") for i in range(16)],
-        title=None,
-        description=None
-    )
+
+    with gr.Row():
+        # LEFT: input + examples + button
+        with gr.Column(scale=1):
+            input_image = gr.Image(type="pil", label="Input H&E")
+            run_btn = gr.Button("Run Prediction")
+            gr.Examples(
+                examples=[
+                    ["examples/crc100k_val.jpg"],
+                    ["examples/orion_test_1.jpg"],
+                    ["examples/orion_test_2.jpg"],
+                    ["examples/orion_test_3.jpg"],
+                    ["examples/orion_test_4.jpg"],
+                    ["examples/orion_test_5.jpg"],
+                    ["examples/tcga.jpg"],
+                    ["examples/hemit.jpg"],
+                ],
+                inputs=[input_image],
+                label="Example H&E tiles (TCGA, ORION Test, CRC100K, HEMIT)"
+            )
+
+        # RIGHT: outputs
+        with gr.Column(scale=2):
+            output_images = [
+                gr.Image(type="pil", label=f"mIF Channel {channel_names[i]}")
+                for i in range(16)
+            ]
+
+    run_btn.click(fn=predict, inputs=input_image, outputs=output_images)
 
 if __name__ == "__main__":
     demo.launch()
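The two visual changes here, contrast and color, are easiest to see with concrete numbers. Below is a minimal sketch of the rescaling that `predict()` now applies, followed by the tint step from `apply_color_map()`; the prediction values are synthetic, and CD31/cyan is just one pairing from `channel_colors`:

```python
import numpy as np
import torch

# Synthetic normalized predictions for one non-Hoechst channel, in [0, 1].
pred = torch.tensor([0.10, 0.30, 0.60])

# Contrast: dividing by 150/255 ≈ 0.588 stretches dim signal, so anything at
# or above ~59% of the range saturates to full brightness; Hoechst keeps the
# neutral 255/255 = 1.0 factor and is left untouched.
factor = 150.0 / 255.0
gray = ((pred / factor).clamp(0, 1) * 255).to(torch.uint8).numpy()
print(gray)  # [ 43 130 255]

# Color: apply_color_map scales a fixed RGB triple by the gray intensity
# (rounded here for a clean printout; the app code truncates instead).
rgb_color = (0, 255, 255)  # CD31 -> cyan
rgb = np.stack([gray / 255.0 * c for c in rgb_color], axis=-1)
print(rgb.round().astype(np.uint8))  # rows: [0 43 43], [0 130 130], [0 255 255]
```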
examples/crc100k_val.jpg ADDED
examples/hemit.jpg ADDED
examples/orion_test_1.jpg ADDED
examples/orion_test_2.jpg ADDED
examples/orion_test_3.jpg ADDED
examples/orion_test_4.jpg ADDED
examples/orion_test_5.jpg ADDED
examples/tcga.jpg ADDED
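The eight tiles added above feed the new `gr.Examples` panel. A quick way to smoke-test them outside the UI is a loop like the following; `predict` and `channel_names` are defined in `app.py`, but running it assumes the model and config loading at the top of `app.py` succeed from the repo root, so treat it as a sketch:

```python
# Hypothetical smoke test for the committed example tiles; importing app runs
# its model/config setup (demo.launch() stays behind the __main__ guard).
from pathlib import Path

from PIL import Image

from app import channel_names, predict

for tile in sorted(Path("examples").glob("*.jpg")):
    channels = predict(Image.open(tile))
    assert len(channels) == len(channel_names) == 16
    print(f"{tile.name}: {len(channels)} pseudocolored channels")
```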