# ariankhalfani's picture
# Update app.py
# 48dc356 verified
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import cv2
import tensorflow as tf
import gradio as gr
import io
def load_model(model_path):
    """Load a saved Keras model from ``model_path`` and compile it.

    The model is always (re)compiled with a fresh Adam optimizer, binary
    cross-entropy loss and an accuracy metric before being returned.
    """
    loaded = tf.keras.models.load_model(model_path)
    loaded.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.BinaryCrossentropy(),
        metrics=['accuracy'],
    )
    return loaded
def get_model_summary(model):
    """Return the text of ``model.summary()`` as a single string.

    ``model.summary`` is called with a ``print_fn`` that captures every
    emitted line (each followed by a newline) instead of printing it.

    Args:
        model: Any object exposing ``summary(print_fn=...)``.

    Returns:
        The captured summary text; empty if nothing was emitted.
    """
    # The context manager closes the buffer even if ``summary`` raises.
    with io.StringIO() as buffer:

        def capture(line):
            buffer.write(line + "\n")

        model.summary(print_fn=capture)
        return buffer.getvalue()
def get_input_shape(model):
    """Return ``model.input_shape`` without its leading (batch) entry."""
    return model.input_shape[1:]
def preprocess_image(image, input_shape):
    """Convert an image into a normalized single-item batch for the model.

    The image is converted to the channel count the model expects, resized
    to the model's spatial size, scaled by 1/255 and given a leading batch
    axis.

    Args:
        image: A PIL image or array-like; anything ``np.array`` accepts.
            Colour data is treated as RGB(A) channel order, which is what
            PIL produces.
        input_shape: Model input shape without the batch dimension, read
            here as ``(height, width, channels)``.

    Returns:
        A float array of shape ``(1, height, width, channels)`` when the
        model expects 1 or 3 channels.
    """
    img = np.array(image)
    target_height, target_width = input_shape[0], input_shape[1]
    num_channels = input_shape[-1]

    # Collapse a trailing singleton channel so the branches below only have
    # to distinguish 2-D grayscale from multi-channel colour input.
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[:, :, 0]

    if num_channels == 1:  # Model expects grayscale
        if img.ndim == 3:
            # The array is in RGB(A) order, so the RGB* conversion codes are
            # required; BGR2GRAY would swap the red and blue weights.
            if img.shape[2] == 4:
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
            else:
                img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    elif num_channels == 3:  # Model expects RGB
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:
            # Drop the alpha channel so the model receives three channels.
            img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

    # cv2.resize takes dsize as (width, height), the reverse of the
    # (height, width) order used by input_shape.
    img_resized = cv2.resize(img, (target_width, target_height))

    # cv2.resize returns a 2-D array for single-channel input, so the channel
    # axis is added after resizing rather than before.
    if num_channels == 1 and img_resized.ndim == 2:
        img_resized = np.expand_dims(img_resized, axis=-1)

    img_normalized = img_resized / 255.0
    return np.expand_dims(img_normalized, axis=0)
def diagnose_image(image, model, input_shape):
    """Run the model on an image and annotate the image with the result.

    The image is preprocessed with ``preprocess_image``, passed to
    ``model.predict``, and the first value of the first prediction is
    reported as the glaucoma probability. A copy of the image is returned
    with a red 200x100 rectangle centred on it and the probability written
    in the middle of that rectangle.

    Args:
        image: The input image (PIL image or array-like).
        model: Object exposing ``predict(batch)``.
        input_shape: Model input shape without the batch dimension.

    Returns:
        A ``(annotated PIL image, result text)`` tuple.
    """
    img_batch = preprocess_image(image, input_shape)
    prediction = model.predict(img_batch)
    glaucoma_probability = prediction[0][0]
    result_text = f"Probability of glaucoma: {glaucoma_probability:.2%}"

    img_display = np.array(image)
    if len(img_display.shape) == 2 or img_display.shape[2] == 1:  # Convert to RGB for display
        img_display = cv2.cvtColor(img_display, cv2.COLOR_GRAY2RGB)
    image_pil = Image.fromarray(img_display)
    draw = ImageDraw.Draw(image_pil)
    font = ImageFont.load_default()
    text = f"{glaucoma_probability:.2%}"

    # ImageDraw.textsize was removed in Pillow 10; prefer textbbox and fall
    # back to textsize only on Pillow versions that predate textbbox.
    if hasattr(draw, "textbbox"):
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    else:
        left, top = 0, 0
        right, bottom = draw.textsize(text, font=font)
    text_width = right - left
    text_height = bottom - top

    rect_width = 200
    rect_height = 100
    rect_x = (image_pil.width - rect_width) // 2
    rect_y = (image_pil.height - rect_height) // 2
    draw.rectangle(
        [rect_x, rect_y, rect_x + rect_width, rect_y + rect_height],
        outline="red",
        width=3,
    )

    # Subtract the bounding-box origin so the visible glyphs, not the text
    # anchor, end up centred inside the rectangle.
    text_x = rect_x + (rect_width - text_width) // 2 - left
    text_y = rect_y + (rect_height - text_height) // 2 - top
    draw.text((text_x, text_y), text, fill="red", font=font)
    return image_pil, result_text
def main():
    """Build the glaucoma-detection Gradio interface and launch it.

    The UI lets the user upload a Keras model file, load it (showing its
    summary), upload an image, and run a diagnosis whose annotated image
    replaces the uploaded one and whose text goes to the result box. The
    loaded model and its input shape are kept in ``gr.State`` between
    clicks.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Glaucoma Detection App")
        gr.Markdown("Upload a fundus eye image to detect the probability of glaucoma.")

        # NOTE(review): the file's original indentation was lost, so which
        # widgets belong inside this row is an assumption; confirm the layout.
        with gr.Row():
            model_file = gr.File(label="Upload Model (.h5 or .keras)")
            load_model_btn = gr.Button("Load Model")
            model_info = gr.Markdown()

        image = gr.Image(type="pil", label="Upload Image")
        submit_btn = gr.Button("Diagnose")
        result = gr.Textbox(label="Diagnosis Result")

        def load_and_display_model_info(file):
            """Load the uploaded model; return it, its summary and input shape."""
            if file is None:
                raise gr.Error("Please upload a model file before clicking 'Load Model'.")
            # Depending on the Gradio version the upload arrives as an object
            # with a ``.name`` path or as a plain path string.
            model_path = getattr(file, "name", file)
            loaded_model = load_model(model_path)
            model_summary = get_model_summary(loaded_model)
            # Fence the summary so Markdown rendering keeps the table layout.
            summary_markdown = f"```\n{model_summary}\n```"
            return loaded_model, summary_markdown, get_input_shape(loaded_model)

        model = gr.State(None)
        input_shape = gr.State(None)

        def diagnose_and_display(image, model, input_shape):
            """Run the diagnosis; return the annotated image and result text."""
            if model is None or input_shape is None:
                raise gr.Error("Please load a model before running a diagnosis.")
            if image is None:
                raise gr.Error("Please upload an image before running a diagnosis.")
            diagnosis_image, result_text = diagnose_image(image, model, input_shape)
            return diagnosis_image, result_text

        load_model_btn.click(
            fn=load_and_display_model_info,
            inputs=model_file,
            outputs=[model, model_info, input_shape],
        )
        submit_btn.click(
            fn=diagnose_and_display,
            inputs=[image, model, input_shape],
            outputs=[image, result],
        )

        gr.Markdown("### Glaucoma Analyzer V.1.0.0 by Thariq Arian")

    demo.launch()
# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()