import gradio as gr
from gradio_client import Client
import numpy as np
#import torch
import requests
from PIL import Image
#from torchvision import transforms

from predict_unet import predict_model
title = "<center><strong><font size='8'> Medical Image Segmentation with UNet </font></strong></center>"

examples = [["examples/50494616.jpg"], ["examples/50494676.jpg"], ["examples/56399783.jpg"],
            ["examples/56399789.jpg"], ["examples/56399831.jpg"], ["examples/56399959.jpg"],
            ["examples/56400014.jpg"], ["examples/56400119.jpg"],
            ["examples/56481903.jpg"], ["examples/70749195.jpg"]]
def run_unetv0(image):
    """Segment with the baseline UNet (v0); clip the output to [0, 1] for display."""
    output = predict_model(image, "v0")
    return np.clip(output, 0, 1)

def run_unetv1(image):
    """Segment with UNet3+ (v1)."""
    output = predict_model(image, "v1")
    return np.clip(output, 0, 1)

def run_unetv2(image):
    """Segment with UNet3+ and deep supervision (v2)."""
    output = predict_model(image, "v2")
    return np.clip(output, 0, 1)

def run_unetv3(image):
    """Segment with UNet3+, deep supervision and CGM (v3)."""
    output = predict_model(image, "v3")
    return np.clip(output, 0, 1)
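
# The four wrappers above differ only in the version string passed to
# predict_model, so an equivalent, more compact setup (a sketch, not how the
# original Space is wired) would be:
#
#   from functools import partial
#
#   def run_unet(image, version):
#       return np.clip(predict_model(image, version), 0, 1)
#
#   run_unetv0, run_unetv1, run_unetv2, run_unetv3 = (
#       partial(run_unet, version=v) for v in ("v0", "v1", "v2", "v3"))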
# input/output image components for each model variant; they are created here
# and rendered inside their respective tabs below
input_img_v0 = gr.Image(label="Input", type='numpy')
segm_img_v0 = gr.Image(label="Segmented Image")

input_img_v1 = gr.Image(label="Input", type='numpy')
segm_img_v1 = gr.Image(label="Segmented Image")

input_img_v2 = gr.Image(label="Input", type='numpy')
segm_img_v2 = gr.Image(label="Segmented Image")

input_img_v3 = gr.Image(label="Input", type='numpy')
segm_img_v3 = gr.Image(label="Segmented Image")
with gr.Blocks(title='UNet examples') as demo:
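    # `title` is defined above but never rendered anywhere in this extract;
    # showing it at the top of the page is an assumption about the intended layout.
    gr.Markdown(title)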
    # v0: regular UNet
    with gr.Tab("Regular UNet (v0)"):
        # display input image and segmented image
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                input_img_v0.render()
            with gr.Column(scale=1):
                segm_img_v0.render()
        # submit and clear
        with gr.Row():
            with gr.Column():
                segment_btn_v0 = gr.Button("Run Segmentation", variant='primary')
                clear_btn_v0 = gr.Button("Clear", variant="secondary")
                # load examples
                gr.Markdown("Try some of the examples below")
                gr.Examples(examples=examples,
                            inputs=[input_img_v0],
                            outputs=segm_img_v0,
                            fn=run_unetv0,
                            cache_examples=False,
                            examples_per_page=5)
            # just a placeholder for the second column
            with gr.Column():
                gr.Markdown("")

        segment_btn_v0.click(run_unetv0,
                             inputs=[input_img_v0],
                             outputs=segm_img_v0)
    # v1: UNet3+
    with gr.Tab("UNet3+ (v1)"):
        # display input image and segmented image
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                input_img_v1.render()
            with gr.Column(scale=1):
                segm_img_v1.render()
        # submit and clear
        with gr.Row():
            with gr.Column():
                segment_btn_v1 = gr.Button("Run Segmentation", variant='primary')
                clear_btn_v1 = gr.Button("Clear", variant="secondary")
                # load examples
                gr.Markdown("Try some of the examples below")
                gr.Examples(examples=examples,
                            inputs=[input_img_v1],
                            outputs=segm_img_v1,
                            fn=run_unetv1,
                            cache_examples=False,
                            examples_per_page=5)
            # just a placeholder for the second column
            with gr.Column():
                gr.Markdown("")

        segment_btn_v1.click(run_unetv1,
                             inputs=[input_img_v1],
                             outputs=segm_img_v1)
    # v2: UNet3+ with deep supervision
    with gr.Tab("UNet3+ (v2) with deep supervision"):
        # display input image and segmented image
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                input_img_v2.render()
            with gr.Column(scale=1):
                segm_img_v2.render()
        # submit and clear
        with gr.Row():
            with gr.Column():
                segment_btn_v2 = gr.Button("Run Segmentation", variant='primary')
                clear_btn_v2 = gr.Button("Clear", variant="secondary")
                # load examples
                gr.Markdown("Try some of the examples below")
                gr.Examples(examples=examples,
                            inputs=[input_img_v2],
                            outputs=segm_img_v2,
                            fn=run_unetv2,
                            cache_examples=False,
                            examples_per_page=5)
            # just a placeholder for the second column
            with gr.Column():
                gr.Markdown("")

        segment_btn_v2.click(run_unetv2,
                             inputs=[input_img_v2],
                             outputs=segm_img_v2)
    # v3: UNet3+ with deep supervision and CGM (classification-guided module)
    with gr.Tab("UNet3+ (v3) with deep supervision and CGM"):
        # display input image and segmented image
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                input_img_v3.render()
            with gr.Column(scale=1):
                segm_img_v3.render()
        # submit and clear
        with gr.Row():
            with gr.Column():
                segment_btn_v3 = gr.Button("Run Segmentation", variant='primary')
                clear_btn_v3 = gr.Button("Clear", variant="secondary")
                # load examples
                gr.Markdown("Try some of the examples below")
                gr.Examples(examples=examples,
                            inputs=[input_img_v3],
                            outputs=segm_img_v3,
                            fn=run_unetv3,
                            cache_examples=False,
                            examples_per_page=5)
            # just a placeholder for the second column
            with gr.Column():
                gr.Markdown("")

        segment_btn_v3.click(run_unetv3,
                             inputs=[input_img_v3],
                             outputs=segm_img_v3)
    # reset both the input image and the segmented image of a tab
    def clear():
        return None, None

    clear_btn_v0.click(clear, outputs=[input_img_v0, segm_img_v0])
    clear_btn_v1.click(clear, outputs=[input_img_v1, segm_img_v1])
    clear_btn_v2.click(clear, outputs=[input_img_v2, segm_img_v2])
    clear_btn_v3.click(clear, outputs=[input_img_v3, segm_img_v3])
demo.queue()
demo.launch()
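
# To try the app locally (assuming the predict_unet module, its model weights and
# the examples/ images are present), run this script (conventionally app.py on a
# Space) and open the printed local URL. demo.queue() serialises requests so
# long-running segmentations do not block one another; demo.launch(share=True)
# would additionally create a temporary public link.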