fatchecker / app.py
bumble-bee's picture
added unet files
6bf4d42
import gradio as gr
from gradio_client import Client
import numpy as np
#import torch
import requests
from PIL import Image
#from torchvision import transforms
from predict_unet import predict_model
# HTML heading for the app. NOTE: nothing else in this file references it.
title = "<center><strong><font size='8'> Medical Image Segmentation with UNet </font></strong></center>"

# Rows for gr.Examples: each row holds a single value for the single input
# component (the input image path).
examples = [
    [path]
    for path in (
        "examples/50494616.jpg",
        "examples/50494676.jpg",
        "examples/56399783.jpg",
        "examples/56399789.jpg",
        "examples/56399831.jpg",
        "examples/56399959.jpg",
        "examples/56400014.jpg",
        "examples/56400119.jpg",
        "examples/56481903.jpg",
        "examples/70749195.jpg",
    )
]
def run_unetv0(input):
output = predict_model(input, "v0")
normalized_output = np.clip(output, 0, 1)
return normalized_output
def run_unetv1(input):
output = predict_model(input, "v1")
normalized_output = np.clip(output, 0, 1)
return normalized_output
def run_unetv2(input):
output = predict_model(input, "v2")
normalized_output = np.clip(output, 0, 1)
return normalized_output
def run_unetv3(input):
output = predict_model(input, "v3")
normalized_output = np.clip(output, 0, 1)
return normalized_output
# One (input, output) image pair per model tab. They are built outside the
# Blocks context and placed into their tab later via .render(). Each tuple is
# evaluated left to right, so creation order stays input, output, input, ...
input_img_v0, segm_img_v0 = gr.Image(label="Input", type='numpy'), gr.Image(label="Segmented Image")
input_img_v1, segm_img_v1 = gr.Image(label="Input", type='numpy'), gr.Image(label="Segmented Image")
input_img_v2, segm_img_v2 = gr.Image(label="Input", type='numpy'), gr.Image(label="Segmented Image")
input_img_v3, segm_img_v3 = gr.Image(label="Input", type='numpy'), gr.Image(label="Segmented Image")
def _build_unet_tab(tab_label, input_img, segm_img, run_fn):
    """Lay out one model tab and wire its "Run Segmentation" button.

    Must be called inside an active ``gr.Blocks`` context.

    Args:
        tab_label: Text shown on the tab header.
        input_img: Pre-built input ``gr.Image`` to render in this tab.
        segm_img: Pre-built output ``gr.Image`` to render in this tab.
        run_fn: Callback that maps the input image to the segmented image.

    Returns:
        ``(segment_btn, clear_btn)``. The clear button is wired by the caller.
    """
    with gr.Tab(tab_label):
        # display input image and segmented image
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                input_img.render()
            with gr.Column(scale=1):
                segm_img.render()
        # submit and clear
        with gr.Row():
            with gr.Column():
                segment_btn = gr.Button("Run Segmentation", variant='primary')
                clear_btn = gr.Button("Clear", variant="secondary")
                # load examples
                gr.Markdown("Try some of the examples below")
                gr.Examples(examples=examples,
                            inputs=[input_img],
                            outputs=segm_img,
                            fn=run_fn,
                            cache_examples=False,
                            examples_per_page=5)
            # just a placeholder for second column
            with gr.Column():
                gr.Markdown("")
        segment_btn.click(run_fn, inputs=[input_img], outputs=segm_img)
    return segment_btn, clear_btn


with gr.Blocks(title='UNet examples') as demo:
    # NOTE(review): the module-level `title` HTML string does not appear to be
    # rendered anywhere in this file; only this browser-tab title is used.
    # Confirm whether a heading (e.g. gr.HTML(title)) was intended.

    # v0: regular UNet
    segment_btn_v0, clear_btn_v0 = _build_unet_tab(
        "Regular UNet (v0)", input_img_v0, segm_img_v0, run_unetv0)
    # v1: UNet3+
    segment_btn_v1, clear_btn_v1 = _build_unet_tab(
        "UNet3+ (v1)", input_img_v1, segm_img_v1, run_unetv1)
    # v2: UNet3+ with deep supervision
    segment_btn_v2, clear_btn_v2 = _build_unet_tab(
        "UNet3+(v2) with deep supervision", input_img_v2, segm_img_v2, run_unetv2)
    # v3: UNet3+ with deep supervision and cgm
    segment_btn_v3, clear_btn_v3 = _build_unet_tab(
        "UNet3+(v3) with deep supervision and cgm", input_img_v3, segm_img_v3, run_unetv3)

    def clear():
        """Reset a tab's (input, output) image pair to empty."""
        return None, None

    # Event listeners must be registered inside the Blocks context.
    clear_btn_v0.click(clear, outputs=[input_img_v0, segm_img_v0])
    clear_btn_v1.click(clear, outputs=[input_img_v1, segm_img_v1])
    clear_btn_v2.click(clear, outputs=[input_img_v2, segm_img_v2])
    clear_btn_v3.click(clear, outputs=[input_img_v3, segm_img_v3])
# Enable request queuing on the demo as soon as it is built.
demo.queue()

# Start the (blocking) web server only when run as a script, e.g.
# `python app.py`; merely importing this module no longer launches it.
if __name__ == "__main__":
    demo.launch()