U1020040 committed on
Commit ba0c28c · 1 Parent(s): 68482bc

overlay version

Files changed (1)
  1. app.py +52 -3
app.py CHANGED
@@ -1,8 +1,10 @@
 import gradio as gr
+from datetime import datetime
 from huggingface_hub import hf_hub_download
 import torch
 import json
 from PIL import Image
+from PIL import ImageDraw, ImageFont
 import numpy as np
 from model import MIPHEIViT
 
@@ -44,6 +46,7 @@ correction_map = {"Hoechst": 255.0, "CD8a": 100, "CD31": 100, "CD4": 100, "CD68"
 max_contrast_correction_value = torch.tensor([
     correction_map.get(name, default_contrast) / 255 for name in channel_names
 ]).reshape(len(channel_names), 1, 1)
+overlay_markers = ["Hoechst", "Pan-CK", "SMA", "CD45"]
 
 
 def preprocess(image):
@@ -52,6 +55,47 @@ def preprocess(image):
     tensor = (tensor - mean) / std
     return tensor.unsqueeze(0)  # [1, 3, H, W]
 
+
+def draw_legend_on_image(image, channel_names, channel_colors, indices, box_size=18, spacing=5, top_margin=5):
+    """Draw a semi-transparent legend on the bottom-right of the image."""
+    overlay = image.convert("RGBA")  # to allow alpha
+    legend_layer = Image.new("RGBA", overlay.size, (255, 255, 255, 0))
+    draw = ImageDraw.Draw(legend_layer)
+    font = ImageFont.load_default()
+
+    legend_height = top_margin + box_size * len(indices) + spacing * (len(indices) - 1)
+    legend_width = 60  # adjust as needed
+    x_start = overlay.width - legend_width - 10
+    y_start = overlay.height - legend_height - 10
+
+    # Semi-transparent background
+    draw.rectangle(
+        [x_start - 5, y_start - 5, x_start + legend_width + 5, y_start + legend_height + 5],
+        fill=(255, 255, 255, 180)  # semi-transparent white
+    )
+
+    for i, idx in enumerate(indices):
+        name = channel_names[idx]
+        color = channel_colors[name]
+        y = y_start + i * (box_size + spacing)
+        draw.rectangle([x_start, y, x_start + box_size, y + box_size], fill=color + (255,))
+        draw.text((x_start + box_size + 5, y), name, fill=(0, 0, 0, 255), font=font)
+
+    # Composite legend onto overlay
+    combined = Image.alpha_composite(overlay, legend_layer)
+    return combined.convert("RGB")  # back to RGB for display
+
+
+def merge_colored_images(color_imgs, top4_idx):
+    # Convert images to float32 NumPy arrays
+    accum = np.zeros_like(np.array(color_imgs[0]), dtype=np.float32)
+    for idx in top4_idx:
+        img = np.array(color_imgs[idx]).astype(np.float32)
+        accum += img  # additive blending
+
+    accum = np.clip(accum, 0, 255).astype(np.uint8)
+    return Image.fromarray(accum, mode='RGB')
+
 def apply_color_map(gray_img, rgb_color):
     """Map a grayscale image to RGB using a fixed pseudocolor."""
     gray = np.asarray(gray_img).astype(np.float32) / 255.0
@@ -59,6 +103,7 @@ def apply_color_map(gray_img, rgb_color):
     return Image.fromarray(rgb, mode='RGB')
 
 def predict(image):
+    print(f"[{datetime.now().isoformat()}] Inference run")
     input_tensor = preprocess(image)
     with torch.inference_mode():
         output = model(input_tensor)[0]  # [16, H, W]
@@ -76,8 +121,10 @@ def predict(image):
         ch_colored = apply_color_map(ch_gray, channel_colors[ch_name])
         channel_imgs.append(ch_colored)
 
-    # Return predicted 16 channels
-    return channel_imgs
+    fixed_idx = [channel_names.index(name) for name in overlay_markers]
+    overlay = merge_colored_images(channel_imgs, fixed_idx)
+    overlay_with_legend = draw_legend_on_image(overlay, channel_names, channel_colors, fixed_idx)
+    return [overlay_with_legend] + channel_imgs
 
 # Markdown header
 with open("HEADER.md", "r", encoding="utf-8") as f:
@@ -110,10 +157,12 @@ with gr.Blocks() as demo:
 
         # RIGHT: outputs
        with gr.Column(scale=2):
-            output_images = [
+            overlay_image = gr.Image(type="pil", label="mIF Overlay")
+            channel_images = [
                 gr.Image(type="pil", label=f"mIF Channel {channel_names[i]}")
                 for i in range(16)
             ]
+            output_images = [overlay_image] + channel_images
 
     run_btn.click(fn=predict, inputs=input_image, outputs=output_images)
 
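
For reference, a minimal sketch of how the two helpers added in this commit could be exercised outside the Gradio app on synthetic data. It assumes merge_colored_images and draw_legend_on_image from the diff above (together with their PIL imports) are available in the session; the channel names, colors, and image size below are placeholders, not the values app.py actually loads.

import numpy as np
from PIL import Image

# Placeholder marker palette (illustrative only; app.py derives its own
# channel_names / channel_colors elsewhere in the file).
channel_names = ["Hoechst", "Pan-CK", "SMA", "CD45"]
channel_colors = {
    "Hoechst": (0, 0, 255),
    "Pan-CK": (0, 255, 0),
    "SMA": (255, 255, 0),
    "CD45": (255, 0, 0),
}

# Synthetic pseudocolored channel images standing in for the model outputs.
rng = np.random.default_rng(0)
color_imgs = []
for name in channel_names:
    gray = rng.random((256, 256), dtype=np.float32)  # fake channel intensity in [0, 1]
    rgb = (gray[..., None] * np.array(channel_colors[name], dtype=np.float32)).astype(np.uint8)
    color_imgs.append(Image.fromarray(rgb, mode="RGB"))

indices = list(range(len(channel_names)))            # overlay all four fake markers
overlay = merge_colored_images(color_imgs, indices)  # additive blend, clipped to uint8
overlay_with_legend = draw_legend_on_image(overlay, channel_names, channel_colors, indices)
overlay_with_legend.save("overlay_preview.png")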