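"""Gradio demo for active-learning-assisted annotation: a frozen BiomedCLIP
foundation model provides image embeddings, a UNet "specialist" proposes pseudo
labels, and a k-means selector picks which pool images to annotate next."""
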
import shutil
import uuid
import zipfile
from copy import deepcopy
from functools import partial
from pathlib import Path

import gradio as gr
import h5py
import numpy as np
import torch
import torchvision.transforms.functional as F
from open_clip import create_model_from_pretrained, get_tokenizer
from PIL import Image
from torch.utils.data import ConcatDataset, DataLoader

from activelearning import KMeanSelector
from datasets import ActiveDataset, ExtendableDataset, ImageDataset
from models.unet import UNet, UnetProcessor
from utils import draw_mask

IMAGES_PER_ROW = 10
IMAGE_SIZE = 256
ROOT_DIR = Path(".")
DATA_DIR = ROOT_DIR / "data"

train_set = []
pool_set = []
current_dataset = "dataset"
feature_dict = None

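# Mutable runtime settings; the parameter widgets below update this object in place.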
class Config:
    def __init__(self):
        self.budget = 10
        self.model = "BiomedCLIP"
        self.device = torch.device("cpu")
        self.batch_size = 4
        self.loaded_feature_weight = 1
        self.sharp_factor = 1
        self.loaded_feature_only = False
        self.model_ckpt = "./init_model.pth"


config = Config()

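# Foundation model: BiomedCLIP (ViT-B/16 + PubMedBERT) loaded from the Hugging Face hub
# and used frozen (eval mode) as an image feature extractor.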
def build_foundation_model(device):
    if config.model == "BiomedCLIP":
        model, preprocess = create_model_from_pretrained(
            "hf-hub:microsoft/biomedclip-pubmedbert_256-vit_base_patch16_224"
        )
        tokenizer = get_tokenizer("hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224")
        model.to(device)
        model.eval()
        return model, preprocess
    else:
        raise RuntimeError(f"Unsupported foundation model: {config.model}")

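# Specialist model: a plain 2D UNet with batch normalization that segments single-channel
# IMAGE_SIZE x IMAGE_SIZE images into 3 classes (background plus the two brush classes).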
def build_specialist_model():
    model = UNet(
        dimension=2,
        input_channels=1,
        output_classes=3,
        channels_list=[32, 64, 128, 256, 512],
        block_type="plain",
        normalization="batch",
    )
    model_processor = UnetProcessor(image_size=(IMAGE_SIZE, IMAGE_SIZE))
    return model, model_processor

specialist_model, specialist_processor = build_specialist_model()

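# Load uploaded checkpoint weights into the shared UNet; mapped to CPU so the demo
# also runs without a GPU.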
def load_specialist_model(model_ckpt):
    try:
        specialist_model.load_state_dict(torch.load(model_ckpt, map_location=torch.device("cpu"), weights_only=True))
        print(f"Loaded {model_ckpt}")
    except Exception as e:
        print(f"Failed to load {model_ckpt}: {e}")

def get_feature_dict(batch_size, device, active_dataset: ActiveDataset):
    dataset = ConcatDataset([active_dataset.get_train_dataset(), active_dataset.get_pool_dataset()])
    dataloader = DataLoader(dataset, batch_size=batch_size)
    model, preprocess = build_foundation_model(device)
    feature_dict = {}
    for sampled_batch in dataloader:
        image_batch = sampled_batch["image"]
        image_list = []
        for image in image_batch:
            image_pil = F.to_pil_image(image).convert("RGB")
            image_list.append(preprocess(image_pil))
        image_batch = torch.stack(image_list, dim=0)
        image_batch = image_batch.to(device)
        with torch.no_grad():
            feature_batch = model.encode_image(image_batch)
        for i in range(len(feature_batch)):
            case_name = sampled_batch["case_name"][i]
            feature_dict[case_name] = feature_batch[i]
    return feature_dict

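# One active-learning round: wrap the current train/pool file lists into datasets,
# (re)compute foundation features if needed, load the specialist checkpoint, and let
# the k-means selector return the `budget` pool samples to annotate next.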
def active_select(
    train_set,
    pool_set,
    budget,
    model_ckpt,
    batch_size,
    device,
    loaded_feature_weight,
    sharp_factor,
    loaded_feature_only,
):
    global feature_dict
    train_dataset = ExtendableDataset(ImageDataset(train_set, image_channels=1, image_size=IMAGE_SIZE))
    pool_dataset = ExtendableDataset(ImageDataset(pool_set, image_channels=1, image_size=IMAGE_SIZE))
    active_dataset = ActiveDataset(train_dataset, pool_dataset)
    if feature_dict is None:
        feature_dict = get_feature_dict(batch_size, device, active_dataset)
    active_selector = KMeanSelector(
        batch_size=4,
        num_workers=1,
        pin_memory=True,
        metric="l2",
        feature_dict=feature_dict,
        loaded_feature_weight=loaded_feature_weight,
        sharp_factor=sharp_factor,
        loaded_feature_only=loaded_feature_only,
    )
    load_specialist_model(model_ckpt)
    return active_selector.select_next_batch(active_dataset, budget, specialist_model, device)

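# Input panel: two galleries where the user drops already-labeled images ("Train set")
# and unlabeled candidates ("Pool set"); any change invalidates the cached features.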
def build_input_ui():
    with gr.Accordion("Input") as blk:
        with gr.Row():
            train_gallery = gr.Gallery(
                label="Train set", allow_preview=False, columns=IMAGES_PER_ROW // 2, show_label=True
            )
            pool_gallery = gr.Gallery(
                label="Pool set", allow_preview=False, columns=IMAGES_PER_ROW // 2, show_label=True
            )

        def gallery_change(image_list, target_set=None):
            global feature_dict
            if image_list is None:
                return
            if target_set == "train":
                global train_set
                train_set = [x[0] for x in image_list]
                feature_dict = None
            elif target_set == "pool":
                global pool_set
                pool_set = [x[0] for x in image_list]
                feature_dict = None

        train_gallery.change(partial(gallery_change, target_set="train"), train_gallery, None)
        pool_gallery.change(partial(gallery_change, target_set="pool"), pool_gallery, None)
    return blk

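# Parameters panel: numeric/dropdown widgets that write straight into the global `config`.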
def build_parameters_ui():
    with gr.Accordion("Parameters") as blk:
        budget_input = gr.Number(config.budget, label="Budget")
        with gr.Row():
            model_ckpt_file = gr.File(visible=False)
            model_ckpt_input = gr.UploadButton(label="Upload model checkpoint")
        device_input = gr.Dropdown(choices=["cuda", "cpu"], value="cpu", label="Device", interactive=True)
        batch_size_input = gr.Number(config.batch_size, label="Batch Size")
        foundation_model_weight_input = gr.Number(config.loaded_feature_weight, label="foundation_model_weight")
        sharp_factor_input = gr.Number(config.sharp_factor, label="sharp_factor")

        def budget_input_change(x):
            config.budget = int(x)

        budget_input.change(budget_input_change, budget_input, None)

        def model_ckpt_input_upload(x):
            config.model_ckpt = x
            return gr.File(value=x, visible=True)

        model_ckpt_input.upload(model_ckpt_input_upload, model_ckpt_input, model_ckpt_file)

        def device_input_change(x):
            config.device = torch.device(x)

        device_input.change(device_input_change, device_input, None)

        def batch_size_input_change(x):
            config.batch_size = int(x)

        batch_size_input.change(batch_size_input_change, batch_size_input, None)

        def foundation_model_weight_input_change(x):
            config.loaded_feature_weight = x

        foundation_model_weight_input.change(foundation_model_weight_input_change, foundation_model_weight_input, None)

        def sharp_factor_input_change(x):
            config.sharp_factor = x

        sharp_factor_input.change(sharp_factor_input_change, sharp_factor_input, None)
    return blk

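# Brush colors map to label indices (red -> class 1, green -> class 2); background stays 0.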
class_color_map = {
    1: "#ff0000",
    2: "#00ff00",
}

selected_image = None
selected_set = []
annotated_set = []

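# Run the specialist UNet on one grayscale PIL image and return a per-pixel class map
# postprocessed back to the input size [H, W].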
def predict_pseudo_label(image_pil):
    image = F.to_tensor(image_pil)
    image = image.unsqueeze(0)
    _, _, H, W = image.shape
    image = specialist_processor.preprocess(image)
    with torch.no_grad():
        pred = specialist_model(image)
    pseudo_label = pred.argmax(1)
    pseudo_label = specialist_processor.postprocess(pseudo_label, [H, W])
    return pseudo_label[0]

def hex_to_rgb(h):
    h = h[1:]
    return [int(h[i : i + 2], 16) for i in range(0, 6, 2)]

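# Active selection panel: "Select" runs a selection round, clicking a selected sample opens
# it in the ImageEditor pre-filled with the UNet pseudo label, "Accept" converts the painted
# layer back into a label mask, and the download button serves the zipped annotations.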
def build_active_selection_ui():
    with gr.Accordion("Active Selection") as blk:
        select_button = gr.Button("Select")
        with gr.Row():
            selected_gallary = gr.Gallery(
                label="Selected samples", allow_preview=False, columns=IMAGES_PER_ROW // 2, show_label=True
            )
            annotated_gallary = gr.Gallery(
                label="Annotated samples",
                allow_preview=True,
                columns=IMAGES_PER_ROW // 2,
                show_label=True,
                interactive=False,
            )
        image_editor = gr.ImageEditor(
            label="Image Editor",
            interactive=True,
            sources=(),
            brush=gr.Brush(colors=[c for c in class_color_map.values()], color_mode="fixed"),
            layers=False,
        )
        accept_button = gr.Button("Accept")
        download_button = gr.DownloadButton(label="Download Annotated Dataset", visible=False)

        def select_button_click():
            global selected_set, current_dataset, train_set, pool_set, config, annotated_set
            annotated_samples = [x["path"] for x in annotated_set]
            selected_set = active_select(
                list(set(train_set + annotated_samples)),
                pool_set,
                config.budget,
                config.model_ckpt,
                config.batch_size,
                config.device,
                config.loaded_feature_weight,
                config.sharp_factor,
                config.loaded_feature_only,
            )
            current_dataset = uuid.uuid4()
            return selected_set

        select_button.click(select_button_click, None, selected_gallary)

        def get_editor_value(image_path):
            image_pil = Image.open(image_path).convert("L")
            background = np.array(image_pil.convert("RGBA"))
            pseudo_label = predict_pseudo_label(image_pil).cpu().numpy()
            layer = np.zeros_like(background)
            for cl, color in class_color_map.items():
                bin_mask = pseudo_label == cl
                layer[bin_mask] = hex_to_rgb(color) + [255]
            return {"background": background, "layers": [layer], "composite": None}

        def gallery_select(data: gr.SelectData):
            global selected_image
            selected_image = {
                "index": data.index,
                "path": data.value["image"]["path"],
            }
            return get_editor_value(selected_image["path"])

        selected_gallary.select(gallery_select, None, image_editor)

        def accept_button_click(value):
            global selected_set, selected_image, annotated_set
            if len(value["layers"]) and selected_image:
                layer_np = value["layers"][0]
                binary_layer_np = np.zeros_like(layer_np)
                binary_layer_np[layer_np > 127] = 255
                H, W, _ = layer_np.shape
                mask_np = np.zeros((H, W), dtype=np.uint8)
                for cl, color in class_color_map.items():
                    color_rgb = hex_to_rgb(color)
                    bin_mask = np.all(binary_layer_np[:, :, :3] == color_rgb, axis=-1)
                    mask_np[bin_mask] = cl
                selected_image["image"] = value["background"]
                selected_image["mask"] = mask_np
                image_pil = F.to_pil_image(value["background"]).convert("RGB")
                selected_image["visual"] = draw_mask(image_pil, mask_np)
                selected_set = [deepcopy(x) for x in selected_set if x != selected_image["path"]]
                annotated_set.append(deepcopy(selected_image))
                new_index = min(selected_image["index"], len(selected_set) - 1)
                if new_index >= 0:
                    selected_image = {"index": new_index, "path": selected_set[new_index]}
                    image_editor = get_editor_value(selected_image["path"])
                else:
                    selected_image = None
                    image_editor = None
            else:
                image_editor = None
            _download_button = gr.DownloadButton(value=create_download_dataset(), visible=True)
            return image_editor, selected_set, [x["visual"] for x in annotated_set], _download_button

        accept_button.click(
            accept_button_click, image_editor, [image_editor, selected_gallary, annotated_gallary, download_button]
        )
    return blk

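# Write every accepted image/mask pair to data/dataset/{images,labels} as PNGs and bundle
# them into data/dataset.zip for the DownloadButton.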
def create_download_dataset():
    dataset_dir = DATA_DIR / "dataset"
    if dataset_dir.exists():
        shutil.rmtree(dataset_dir)
    dataset_dir.mkdir(exist_ok=True, parents=True)
    images_dir = dataset_dir / "images"
    labels_dir = dataset_dir / "labels"
    images_dir.mkdir(exist_ok=True, parents=True)
    labels_dir.mkdir(exist_ok=True, parents=True)
    zip_file = DATA_DIR / "dataset.zip"
    with zipfile.ZipFile(zip_file, "w") as archive:
        for sample in annotated_set:
            case_name = Path(sample["path"]).stem
            image_np = sample["image"]
            label_np = sample["mask"]
            image_pil = Image.fromarray(image_np)
            label_pil = Image.fromarray(label_np)
            image_pil.save(images_dir / f"{case_name}.png")
            label_pil.save(labels_dir / f"{case_name}.png")
            archive.write(images_dir / f"{case_name}.png", arcname=f"images/{case_name}.png")
            archive.write(labels_dir / f"{case_name}.png", arcname=f"labels/{case_name}.png")
    return zip_file

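# Assemble the three panels into a single Blocks app and launch the demo.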
if __name__ == "__main__":
    with gr.Blocks() as demo:
        input_ui = build_input_ui()
        parameters_ui = build_parameters_ui()
        active_selection_ui = build_active_selection_ui()
    demo.launch(inbrowser=True)