tnk2908 committed on
Commit b7e7adb · 1 Parent(s): 0f59ee5
Files changed (2)
  1. app.py +382 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,382 @@
+ import shutil
+ import uuid
+ import zipfile
+ from copy import deepcopy
+ from functools import partial
+ from pathlib import Path
+
+ import gradio as gr
+ import h5py
+ import numpy as np
+ import torch
+ import torchvision.transforms.functional as F
+ from open_clip import create_model_from_pretrained, get_tokenizer
+ from PIL import Image
+ from torch.utils.data import ConcatDataset, DataLoader
+
+ from activelearning import KMeanSelector
+ from datasets import ActiveDataset, ExtendableDataset, ImageDataset
+ from models.unet import UNet, UnetProcessor
+ from utils import draw_mask
+
+ IMAGES_PER_ROW = 10
+ IMAGE_SIZE = 256
+ ROOT_DIR = Path(".")
+ DATA_DIR = ROOT_DIR / "data"
+
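+ # Module-level state shared across Gradio callbacks: the current train/pool
+ # image lists and a cache of foundation-model features (reset whenever the
+ # galleries change).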
+ train_set = []
+ pool_set = []
+ current_dataset = "dataset"
+ feature_dict = None
+
+
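+ # Default runtime settings; most of these can be overridden from the
+ # parameters accordion in the UI.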
+ class Config:
+     def __init__(self):
+         self.budget = 10
+         self.model = "BiomedCLIP"
+         self.device = torch.device("cpu")
+         self.batch_size = 4
+         self.loaded_feature_weight = 1
+         self.sharp_factor = 1
+         self.loaded_feature_only = False
+         self.model_ckpt = "./init_model.pth"
+
+
+ config = Config()
+
+
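+ # Builds the BiomedCLIP foundation model (loaded through open_clip from the
+ # Hugging Face Hub); it is used only as a frozen image encoder for feature
+ # extraction.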
+ def build_foundation_model(device):
+     if config.model == "BiomedCLIP":
+         model, preprocess = create_model_from_pretrained(
+             "hf-hub:microsoft/biomedclip-pubmedbert_256-vit_base_patch16_224"
+         )
+         tokenizer = get_tokenizer("hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224")
+         model.to(device)
+         model.eval()
+         return model, preprocess
+     else:
+         raise RuntimeError(f"Unsupported foundation model: {config.model}")
+
+
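+ # The specialist model is a 2D U-Net segmenter (1 input channel, 3 output
+ # classes); it ranks pool images during active selection and proposes
+ # pseudo-labels in the image editor.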
+ def build_specialist_model():
+     model = UNet(
+         dimension=2,
+         input_channels=1,
+         output_classes=3,
+         channels_list=[32, 64, 128, 256, 512],
+         block_type="plain",
+         normalization="batch",
+     )
+     model_processor = UnetProcessor(image_size=(IMAGE_SIZE, IMAGE_SIZE))
+     return model, model_processor
+
+
+ specialist_model, specialist_processor = build_specialist_model()
+
+
+ def load_specialist_model(model_ckpt):
+     specialist_model.load_state_dict(torch.load(model_ckpt, map_location=torch.device("cpu"), weights_only=True))
+
+
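+ # Encodes every train and pool image with the foundation model and returns a
+ # {case_name: feature_vector} dictionary that the selector can reuse.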
+ def get_feature_dict(batch_size, device, active_dataset: ActiveDataset):
+     dataset = ConcatDataset([active_dataset.get_train_dataset(), active_dataset.get_pool_dataset()])
+     dataloader = DataLoader(dataset, batch_size=batch_size)
+
+     model, preprocess = build_foundation_model(device)
+     feature_dict = {}
+
+     for sampled_batch in dataloader:
+         image_batch = sampled_batch["image"]
+         image_list = []
+         for image in image_batch:
+             image_pil = F.to_pil_image(image).convert("RGB")
+             image_list.append(preprocess(image_pil))
+         image_batch = torch.stack(image_list, dim=0)
+         image_batch = image_batch.to(device)
+
+         with torch.no_grad():
+             feature_batch = model.encode_image(image_batch)
+
+         for i in range(len(feature_batch)):
+             case_name = sampled_batch["case_name"][i]
+             feature_dict[case_name] = feature_batch[i]
+
+     return feature_dict
+
+
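+ # Wraps the labelled set and the unlabelled pool in an ActiveDataset and asks
+ # the KMeanSelector to pick the next `budget` samples, combining the cached
+ # foundation-model features with the current specialist checkpoint.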
+ def active_select(
+     train_set,
+     pool_set,
+     budget,
+     model_ckpt,
+     batch_size,
+     device,
+     loaded_feature_weight,
+     sharp_factor,
+     loaded_feature_only,
+ ):
+     global feature_dict
+     train_dataset = ExtendableDataset(ImageDataset(train_set, image_channels=1, image_size=IMAGE_SIZE))
+     pool_dataset = ExtendableDataset(ImageDataset(pool_set, image_channels=1, image_size=IMAGE_SIZE))
+     active_dataset = ActiveDataset(train_dataset, pool_dataset)
+     if feature_dict is None:
+         feature_dict = get_feature_dict(batch_size, device, active_dataset)
+
+     active_selector = KMeanSelector(
+         batch_size=4,
+         num_workers=1,
+         pin_memory=True,
+         metric="l2",
+         feature_dict=feature_dict,
+         loaded_feature_weight=loaded_feature_weight,
+         sharp_factor=sharp_factor,
+         loaded_feature_only=loaded_feature_only,
+     )
+     load_specialist_model(model_ckpt)
+     return active_selector.select_next_batch(active_dataset, budget, specialist_model, device)
+
+
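+ # "Input" accordion: two galleries holding the labelled (train) and unlabelled
+ # (pool) images; any change updates the global lists and invalidates the cached
+ # feature dictionary.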
+ def build_input_ui():
+     with gr.Accordion("Input") as blk:
+         with gr.Row():
+             train_gallery = gr.Gallery(
+                 label="Train set", allow_preview=False, columns=IMAGES_PER_ROW // 2, show_label=True
+             )
+             pool_gallery = gr.Gallery(
+                 label="Pool set", allow_preview=False, columns=IMAGES_PER_ROW // 2, show_label=True
+             )
+
+         def gallery_change(image_list, target_set=None):
+             global feature_dict
+             if image_list is None:
+                 return
+
+             if target_set == "train":
+                 global train_set
+                 train_set = [x[0] for x in image_list]
+                 feature_dict = None
+             elif target_set == "pool":
+                 global pool_set
+                 pool_set = [x[0] for x in image_list]
+                 feature_dict = None
+
+         train_gallery.change(partial(gallery_change, target_set="train"), train_gallery, None)
+         pool_gallery.change(partial(gallery_change, target_set="pool"), pool_gallery, None)
+
+     return blk
+
+
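+ # Parameters accordion: numeric/text inputs that write straight into the shared
+ # Config instance through per-field change callbacks.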
+ def build_parameters_ui():
+     with gr.Accordion("Parameters") as blk:
+         budget_input = gr.Number(config.budget, label="Budget")
+         model_ckpt_input = gr.Text(config.model_ckpt, label="Specialist Model Checkpoint")
+         device_input = gr.Dropdown(choices=["cuda", "cpu"], value="cpu", label="Device", interactive=True)
+         batch_size_input = gr.Number(config.batch_size, label="Batch Size")
+         foundation_model_weight_input = gr.Number(config.loaded_feature_weight, label="foundation_model_weight")
+         sharp_factor_input = gr.Number(config.sharp_factor, label="sharp_factor")
+
+         def budget_input_change(x):
+             config.budget = int(x)
+
+         budget_input.change(budget_input_change, budget_input, None)
+
+         def model_ckpt_input_change(x):
+             config.model_ckpt = x
+
+         model_ckpt_input.change(model_ckpt_input_change, model_ckpt_input, None)
+
+         def device_input_change(x):
+             config.device = torch.device(x)
+
+         device_input.change(device_input_change, device_input, None)
+
+         def batch_size_input_change(x):
+             config.batch_size = int(x)
+
+         batch_size_input.change(batch_size_input_change, batch_size_input, None)
+
+         def foundation_model_weight_input_change(x):
+             config.loaded_feature_weight = x
+
+         foundation_model_weight_input.change(foundation_model_weight_input_change, foundation_model_weight_input, None)
+
+         def sharp_factor_input_change(x):
+             config.sharp_factor = x
+
+         sharp_factor_input.change(sharp_factor_input_change, sharp_factor_input, None)
+     return blk
+
+
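+ # Annotation state: mapping from mask class to brush colour, the image currently
+ # open in the editor, the batch returned by the selector, and the accepted
+ # annotations.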
+ class_color_map = {
+     1: "#ff0000",
+     2: "#00ff00",
+ }
+ selected_image = None
+ selected_set = []
+ annotated_set = []
+
+
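+ # Runs the specialist U-Net on a single PIL image and returns its predicted
+ # mask resized back to the original resolution.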
+ def predict_pseudo_label(image_pil):
+     image = F.to_tensor(image_pil)
+     image = image.unsqueeze(0)
+     _, _, H, W = image.shape
+     image = specialist_processor.preprocess(image)
+     with torch.no_grad():
+         pred = specialist_model(image)
+     pseudo_label = pred.argmax(1)
+     pseudo_label = specialist_processor.postprocess(pseudo_label, [H, W])
+
+     return pseudo_label[0]
+
+
+ def hex_to_rgb(h):
+     h = h[1:]
+     return [int(h[i : i + 2], 16) for i in range(0, 6, 2)]
+
+
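+ # "Active Selection" accordion: runs the selector, shows the selected batch,
+ # lets the user paint masks over the U-Net pseudo-labels in the ImageEditor,
+ # and collects accepted annotations for download as a zip archive.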
+ def build_active_selection_ui():
+     with gr.Accordion("Active Selection") as blk:
+         select_button = gr.Button("Select")
+
+         with gr.Row():
+             selected_gallery = gr.Gallery(
+                 label="Selected samples", allow_preview=False, columns=IMAGES_PER_ROW // 2, show_label=True
+             )
+             annotated_gallery = gr.Gallery(
+                 label="Annotated samples",
+                 allow_preview=True,
+                 columns=IMAGES_PER_ROW // 2,
+                 show_label=True,
+                 interactive=False,
+             )
+
+         image_editor = gr.ImageEditor(
+             label="Image Editor",
+             interactive=True,
+             sources=(),
+             brush=gr.Brush(colors=[c for c in class_color_map.values()], color_mode="fixed"),
+             layers=False,
+         )
+         accept_button = gr.Button("Accept")
+
+         download_button = gr.DownloadButton(label="Download Annotated Dataset", visible=False)
+
+         def select_button_click():
+             global selected_set, current_dataset, train_set, pool_set, config, annotated_set
+             annotated_samples = [x["path"] for x in annotated_set]
+             selected_set = active_select(
+                 list(set(train_set + annotated_samples)),
+                 pool_set,
+                 config.budget,
+                 config.model_ckpt,
+                 config.batch_size,
+                 config.device,
+                 config.loaded_feature_weight,
+                 config.sharp_factor,
+                 config.loaded_feature_only,
+             )
+             current_dataset = uuid.uuid4()
+             return selected_set
+
+         select_button.click(select_button_click, None, selected_gallery)
+
+         def get_editor_value(image_path):
+             image_pil = Image.open(image_path).convert("L")
+             background = np.array(image_pil.convert("RGBA"))
+             pseudo_label = predict_pseudo_label(image_pil).cpu().numpy()
+             layer = np.zeros_like(background)
+             for cl, color in class_color_map.items():
+                 bin_mask = pseudo_label == cl
+                 layer[bin_mask] = hex_to_rgb(color) + [255]
+
+             return {"background": background, "layers": [layer], "composite": None}
+
+         def gallery_select(data: gr.SelectData):
+             global selected_image
+             selected_image = {
+                 "index": data.index,
+                 "path": data.value["image"]["path"],
+             }
+             return get_editor_value(selected_image["path"])
+
+         selected_gallery.select(gallery_select, None, image_editor)
+
+         def accept_button_click(value):
+             global selected_set, selected_image, annotated_set
+             if len(value["layers"]) and selected_image:
+                 layer_np = value["layers"][0]
+                 binary_layer_np = np.zeros_like(layer_np)
+                 binary_layer_np[layer_np > 127] = 255
+                 H, W, _ = layer_np.shape
+                 mask_np = np.zeros((H, W), dtype=np.uint8)
+                 for cl, color in class_color_map.items():
+                     color_rgb = hex_to_rgb(color)
+                     bin_mask = np.all(binary_layer_np[:, :, :3] == color_rgb, axis=-1)
+                     mask_np[bin_mask] = cl
+
+                 selected_image["image"] = value["background"]
+                 selected_image["mask"] = mask_np
+                 image_pil = F.to_pil_image(value["background"]).convert("RGB")
+                 selected_image["visual"] = draw_mask(image_pil, mask_np)
+
+                 selected_set = [deepcopy(x) for x in selected_set if x != selected_image["path"]]
+                 annotated_set.append(deepcopy(selected_image))
+                 new_index = min(selected_image["index"], len(selected_set) - 1)
+                 if new_index >= 0:
+                     selected_image = {"index": new_index, "path": selected_set[new_index]}
+                     image_editor = get_editor_value(selected_image["path"])
+                 else:
+                     selected_image = None
+                     image_editor = None
+             else:
+                 image_editor = None
+
+             _download_button = gr.DownloadButton(value=create_download_dataset(), visible=True)
+             return image_editor, selected_set, [x["visual"] for x in annotated_set], _download_button
+
+         accept_button.click(
+             accept_button_click, image_editor, [image_editor, selected_gallery, annotated_gallery, download_button]
+         )
+
+     return blk
+
+
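+ # Writes each accepted image/mask pair to data/dataset/{images,labels} as PNGs
+ # and bundles them into data/dataset.zip for the download button.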
+ def create_download_dataset():
+     dataset_dir = DATA_DIR / "dataset"
+     if dataset_dir.exists():
+         shutil.rmtree(dataset_dir)
+     dataset_dir.mkdir(exist_ok=True, parents=True)
+
+     images_dir = dataset_dir / "images"
+     labels_dir = dataset_dir / "labels"
+
+     images_dir.mkdir(exist_ok=True, parents=True)
+     labels_dir.mkdir(exist_ok=True, parents=True)
+
+     zip_file = DATA_DIR / "dataset.zip"
+
+     with zipfile.ZipFile(zip_file, "w") as archive:
+         for sample in annotated_set:
+             case_name = Path(sample["path"]).stem
+             image_np = sample["image"]
+             label_np = sample["mask"]
+
+             image_pil = Image.fromarray(image_np)
+             label_pil = Image.fromarray(label_np)
+
+             image_pil.save(images_dir / f"{case_name}.png")
+             label_pil.save(labels_dir / f"{case_name}.png")
+
+             archive.write(images_dir / f"{case_name}.png", arcname=f"images/{case_name}.png")
+             archive.write(labels_dir / f"{case_name}.png", arcname=f"labels/{case_name}.png")
+
+     return zip_file
+
+
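+ # Assemble the three UI sections into one Blocks app and launch it.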
+ if __name__ == "__main__":
+     with gr.Blocks() as demo:
+         input_ui = build_input_ui()
+         parameters_ui = build_parameters_ui()
+         active_selection_ui = build_active_selection_ui()
+     demo.launch(inbrowser=True)
requirements.txt ADDED
@@ -0,0 +1 @@
+ git+https://github.com/trnKhanh/medical-image-analysis