AI-RESEARCHER-2024's picture
Update app.py
01ca724 verified
import gradio as gr
import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
import cv2
from tensorflow.keras.models import load_model
from PIL import Image
import os
import pickle
from tensorflow.keras import backend as K
# I/O image dimensions
# All three sizes are square; they are passed to cv2.resize as the target size
# (DISPLAY_DIMS, SEG_DIMS) or mirror the 224x224 Resize used by the
# classification transform (CLASS_DIMS).
DISPLAY_DIMS = (256, 256) # For display
CLASS_DIMS = (224, 224) # For classification model input
SEG_DIMS = (128, 128) # For segmentation model input
# Device configuration
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define Dice Coefficient function for TensorFlow segmentation model
def dice_coefficient(y_true, y_pred, smooth=1):
    """Return the smoothed Dice coefficient between two tensors.

    Both inputs are cast to float32 and flattened, so only the total number
    of elements has to agree. The result is
    ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted mask tensor.
        smooth: Additive term applied to numerator and denominator; keeps the
            ratio defined when both masks are empty.

    Returns:
        A scalar tensor holding the Dice score.
    """
    truth_flat = K.flatten(tf.cast(y_true, tf.float32))
    pred_flat = K.flatten(tf.cast(y_pred, tf.float32))
    overlap = K.sum(truth_flat * pred_flat)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth_flat) + K.sum(pred_flat) + smooth
    return numerator / denominator
# Define Classification Model (PyTorch)
class ClassificationModel(nn.Module):
    """Four-layer fully connected classifier with two output logits.

    Layer widths are ``input_dim -> 128 -> 64 -> 16 -> 2``. ReLU follows the
    first three linear layers, and dropout (p=0.3) follows the first two
    activations. The attribute names ``fc1``..``fc4`` are the keys of the
    state dict, so they must not be renamed.
    """

    def __init__(self, input_dim):
        """Create the layers.

        Args:
            input_dim: Number of input features per sample.
        """
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 16)
        self.fc4 = nn.Linear(16, 2)  # Binary classification: two logits
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Map a batch of feature vectors to raw (un-normalised) class logits."""
        hidden = self.dropout(F.relu(self.fc1(x)))
        hidden = self.dropout(F.relu(self.fc2(hidden)))
        hidden = F.relu(self.fc3(hidden))
        return self.fc4(hidden)
# Load models
# Everything is loaded inside one try block: if any step fails, the app falls
# back to demo mode (models_loaded = False and every model handle set to None).
try:
    # Load ResNet feature extractor
    resnet = models.resnet18(pretrained=True)
    resnet = nn.Sequential(*list(resnet.children())[:-1])  # Remove FC layer
    resnet.to(device)
    resnet.eval()
    # Load Feature Selector
    # SECURITY NOTE: pickle.load executes arbitrary code embedded in the file.
    # Only ship a feature_selector.pkl that comes from a trusted source.
    with open("feature_selector.pkl", "rb") as f:
        selector = pickle.load(f)
    # Load Classification Model
    # get_support() is a boolean mask; its sum is the number of selected
    # features. Coerce the numpy scalar to a plain int for nn.Linear.
    input_dim = int(selector.get_support().sum())
    classification_model = ClassificationModel(input_dim).to(device)
    classification_model.load_state_dict(torch.load("trained_model.pth", map_location=device))
    classification_model.eval()
    # Image transformation for PyTorch model
    transform = transforms.Compose([
        transforms.Resize(CLASS_DIMS),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Load segmentation model: first weights file that exists wins.
    segmentation_model = None
    for weights_path in ("segmentation_model.h5", "best_model.keras"):
        if os.path.exists(weights_path):
            segmentation_model = load_model(
                weights_path,
                custom_objects={'dice_coefficient': dice_coefficient},
                compile=False,
            )
            print(f"Loaded {weights_path}")
            break
    else:
        # Previously silent: make it visible that segmentation will be simulated.
        print("No segmentation model file found; segmentation will use simulated output.")
    models_loaded = True
    print("Models loaded successfully!")
except Exception as e:
    # Deliberate best-effort: any failure switches the whole app to demo mode.
    print(f"Error loading models: {e}")
    print("The app will run in demo mode with simulated predictions.")
    models_loaded = False
    resnet = None
    selector = None
    classification_model = None
    segmentation_model = None
    transform = None
# Function to preprocess image for classification
def preprocess_for_classification(image):
    """Turn an input image into a single-item batch tensor on ``device``.

    Non-PIL inputs are converted through a numpy array first. The image is
    forced to RGB, run through the module-level ``transform`` and given a
    leading batch dimension.

    Args:
        image: A PIL image or anything ``np.array`` can convert to one.

    Returns:
        The transformed tensor with a batch axis, moved to ``device``.
    """
    if isinstance(image, Image.Image):
        pil_image = image
    else:
        pil_image = Image.fromarray(np.array(image))
    rgb_image = pil_image.convert("RGB")  # Ensure RGB
    batch = transform(rgb_image).unsqueeze(0)
    return batch.to(device)
# Function to preprocess image for segmentation
def preprocess_for_segmentation(image):
    """Prepare an image as input for the segmentation model.

    Steps: convert PIL to numpy, expand grayscale / drop alpha so the array
    has three channels, resize to ``SEG_DIMS``, scale by 1/255 and prepend a
    batch axis.

    Args:
        image: A PIL image or a numpy array (2-D grayscale, or 3-D with the
            channel count on the last axis).

    Returns:
        A float array with a leading batch axis of length one.
    """
    pixels = np.array(image) if isinstance(image, Image.Image) else image
    # Convert to RGB if needed
    if len(pixels.shape) == 2:  # Grayscale
        pixels = cv2.cvtColor(pixels, cv2.COLOR_GRAY2RGB)
    elif pixels.shape[2] == 4:  # RGBA
        pixels = cv2.cvtColor(pixels, cv2.COLOR_RGBA2RGB)
    # Resize to the segmentation model's input size, then normalize
    scaled = cv2.resize(pixels, SEG_DIMS) / 255.0
    # Add batch dimension
    return np.expand_dims(scaled, axis=0)
# Function to classify COVID-19 using PyTorch model
def classify_image(image):
    """Classify an image as COVID / Non-COVID.

    When the real models are available the image goes through the ResNet
    feature extractor, the feature selector and the classifier. Otherwise a
    random prediction tagged ``[DEMO]`` is produced.

    Args:
        image: Input image, or ``None``.

    Returns:
        A ``(message, image, predicted_class)`` tuple. ``image`` is the
        unmodified input (``None`` when no image was given). On any
        exception the message describes the error and the class is ``0``.
    """
    if image is None:
        return "No image provided", None, 0
    try:
        real_models_ready = (
            models_loaded and resnet is not None and classification_model is not None
        )
        if not real_models_ready:
            # Demo mode with simulated predictions
            import random
            predicted_class = random.randint(0, 1)  # 0 or 1
            confidence = random.uniform(0.7, 0.99)
            status = "COVID" if predicted_class == 0 else "Non-COVID"
            return f"Predicted: {status} (Class: {predicted_class}, Confidence: {confidence:.2f}) [DEMO]", image, predicted_class

        # Preprocess and extract a flat feature vector
        batch = preprocess_for_classification(image)
        with torch.no_grad():
            feature_vector = resnet(batch).view(-1).cpu().numpy()
        # Keep only the features chosen by the feature selector
        selected = selector.transform(feature_vector.reshape(1, -1))
        input_tensor = torch.tensor(selected, dtype=torch.float32).to(device)
        # Make prediction
        with torch.no_grad():
            output = classification_model(input_tensor)
        print("Classification output:",output)
        predicted_class = torch.argmax(output, dim=1).item()
        print("Classification predicted class:",predicted_class)
        probabilities = F.softmax(output, dim=1)
        print("Classification probabilities:",probabilities)
        confidence = probabilities[0][predicted_class].item()
        # Map class index to label (0 -> COVID, 1 -> Non-COVID)
        if predicted_class == 0:
            status = "COVID"
        else:
            status = "Non-COVID"
        return f"Predicted: {status} (Class: {predicted_class}, Confidence: {confidence:.2f})", image, predicted_class
    except Exception as e:
        return f"Error during classification: {str(e)}", image, 0
# Function to segment lesions in CT images
def segment_image(image):
    """Segment lesions in a CT image and build display images.

    Every path returns the same two-element shape, because the caller
    unpacks exactly two values.

    Args:
        image: Input image (PIL image or array-like), or ``None``.

    Returns:
        A ``(segmentation_map, overlay)`` tuple:

        * normal case: an RGB heat-map of the mask and the input image with
          the mask tinted red, both resized to ``DISPLAY_DIMS``;
        * ``image is None``: ``(None, None)``;
        * any exception: ``(None, image)`` after printing the error, so the
          UI still shows the untouched input.
    """
    if image is None:
        return None, None
    try:
        if not (models_loaded and segmentation_model is not None):
            # Demo mode with simulated segmentation
            return simulate_segmentation(image)

        # Predict the mask and threshold it to {0, 1}
        input_image = preprocess_for_segmentation(image)
        pred_mask = segmentation_model.predict(input_image)
        binary_mask = (pred_mask > 0.5).astype(np.uint8)

        # Bring both the original image and the mask to the display size
        display_image = cv2.resize(np.array(image), DISPLAY_DIMS)
        display_mask = cv2.resize(binary_mask[0].squeeze(), DISPLAY_DIMS)

        # The overlay must be 3-channel before the per-channel tinting below
        overlay = display_image.copy()
        if len(overlay.shape) == 2:  # If grayscale
            overlay = cv2.cvtColor(overlay, cv2.COLOR_GRAY2RGB)
        elif overlay.shape[2] == 4:  # If RGBA
            overlay = cv2.cvtColor(overlay, cv2.COLOR_RGBA2RGB)

        # Apply red mask on segmented areas: saturate red, halve green/blue
        inside = display_mask > 0
        overlay[:, :, 0] = np.maximum(overlay[:, :, 0], display_mask * 255)
        overlay[:, :, 1] = np.where(inside, overlay[:, :, 1] * 0.5, overlay[:, :, 1])
        overlay[:, :, 2] = np.where(inside, overlay[:, :, 2] * 0.5, overlay[:, :, 2])

        # Stretch the 0/1 mask to 0..255 and colour it for visibility
        enhanced_mask = cv2.normalize(display_mask, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
        enhanced_mask = cv2.applyColorMap(enhanced_mask, cv2.COLORMAP_JET)
        # applyColorMap yields BGR; the rest of the app handles images as RGB,
        # so convert to avoid swapped red/blue in the displayed heat-map.
        enhanced_mask = cv2.cvtColor(enhanced_mask, cv2.COLOR_BGR2RGB)
        return enhanced_mask, overlay
    except Exception as e:
        # Best-effort: report the problem and fall back to the original image.
        print(f"Error during segmentation: {str(e)}")
        return None, image
# Function to simulate segmentation for demo mode
def simulate_segmentation(image):
    """Produce a fake segmentation result for demo mode.

    Draws one to three random filled circles as the "lesion" mask, tints
    those pixels red on a copy of the input, and renders the mask as a
    heat-map.

    Args:
        image: Input image (PIL image or array-like).

    Returns:
        A ``(segmentation_map, overlay)`` tuple of RGB uint8 images resized
        to ``DISPLAY_DIMS``.
    """
    import random

    display_image = np.array(image)
    if len(display_image.shape) == 2:
        display_image = cv2.cvtColor(display_image, cv2.COLOR_GRAY2RGB)
    elif display_image.shape[2] == 4:
        display_image = cv2.cvtColor(display_image, cv2.COLOR_RGBA2RGB)
    display_image = cv2.resize(display_image, DISPLAY_DIMS)

    # Blank mask shaped like the resized image (rows, cols). Taking the shape
    # from the image keeps mask and overlay aligned even if DISPLAY_DIMS is
    # ever made non-square (cv2.resize takes (width, height)).
    mask = np.zeros(display_image.shape[:2], dtype=np.uint8)

    # Simulate random blobs, kept at least 50 px away from the border
    num_blobs = random.randint(1, 3)
    for _ in range(num_blobs):
        center_x = random.randint(50, DISPLAY_DIMS[0] - 50)
        center_y = random.randint(50, DISPLAY_DIMS[1] - 50)
        radius = random.randint(10, 30)
        cv2.circle(mask, (center_x, center_y), radius, 1, -1)

    # Apply red mask on segmented areas: saturate red, halve green/blue
    overlay = display_image.copy()
    inside = mask > 0
    overlay[:, :, 0] = np.maximum(overlay[:, :, 0], mask * 255)
    overlay[:, :, 1] = np.where(inside, overlay[:, :, 1] * 0.5, overlay[:, :, 1])
    overlay[:, :, 2] = np.where(inside, overlay[:, :, 2] * 0.5, overlay[:, :, 2])

    # Stretch the 0/1 mask to 0..255 and colour it for visibility
    enhanced_mask = cv2.normalize(mask, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    enhanced_mask = cv2.applyColorMap(enhanced_mask, cv2.COLORMAP_JET)
    # applyColorMap yields BGR; the rest of the app handles images as RGB,
    # so convert to avoid swapped red/blue in the displayed heat-map.
    enhanced_mask = cv2.cvtColor(enhanced_mask, cv2.COLOR_BGR2RGB)
    return enhanced_mask, overlay
# Function to run both classification and segmentation
def process_image(image):
    """Run classification and segmentation on one image for the UI.

    Args:
        image: Input image, or ``None``.

    Returns:
        A 4-tuple ``(overlay_image, result_text, segmentation_map,
        classification_text)``. Both text slots carry the classification
        message. With no image the tuple is
        ``(None, "No image provided", None, "No image provided")``.
    """
    if image is None:
        return None, "No image provided", None, "No image provided"
    # Only the message is needed; the echoed image and class index are unused.
    classification_result = classify_image(image)[0]
    # Segmentation runs for every image regardless of the predicted class.
    # Take the LAST two elements so this works whether segment_image returns
    # (map, overlay) or a longer tuple with a leading status message; a plain
    # two-name unpack would raise ValueError on the longer form.
    *_, segmentation_map, overlay_image = segment_image(image)
    return overlay_image, classification_result, segmentation_map, classification_result
# Load example images
def _load_examples(file_stem, label):
    """Collect up to five example entries for one class.

    Looks for ``./examples/<file_stem> (1..5).png``. If none exist, five grey
    256x256 placeholder images captioned ``"<label> Example <i>"`` are
    generated instead.

    Args:
        file_stem: File-name prefix on disk (e.g. ``"Covid"``).
        label: Human-readable class name used in captions and log messages.

    Returns:
        A list of single-item lists (one per example), each holding either a
        file path or a placeholder image array. Errors are printed and
        whatever was collected so far is returned.
    """
    examples = []
    try:
        for i in range(1, 6):
            path = f"./examples/{file_stem} ({i}).png"
            if os.path.exists(path):
                examples.append([path])
        # If no example files were found, create placeholders
        if not examples:
            for i in range(1, 6):
                placeholder = np.full((256, 256, 3), 200, dtype=np.uint8)
                cv2.putText(placeholder, f"{label} Example {i}", (30, 128),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (100, 100, 100), 2)
                examples.append([placeholder])
    except Exception as e:
        # Best-effort: missing examples must not stop the app from starting.
        print(f"Could not load {label} examples: {e}")
    return examples


def load_covid_examples():
    """Return the COVID example entries (files or generated placeholders)."""
    return _load_examples("Covid", "COVID")


def load_non_covid_examples():
    """Return the Non-COVID example entries (files or generated placeholders)."""
    return _load_examples("Non-Covid", "Non-COVID")
class GradioInterface:
    """Builds the Gradio Blocks UI for the CT analysis app."""

    def __init__(self):
        """Load the example entries shown under the upload widget."""
        self.covid_examples = load_covid_examples()
        self.non_covid_examples = load_non_covid_examples()

    def create_interface(self):
        """Assemble and return the ``gr.Blocks`` application.

        The page consists of a styled HTML header, an input column (upload,
        analyze button, two example galleries) and an output column with two
        tabs. Clicking the button calls ``process_image`` and fills the four
        output components.

        NOTE(review): the component nesting below was reconstructed from a
        source whose indentation had been stripped - confirm the layout.
        The multi-line string bodies are kept flush-left so their content
        stays exactly as found in the source.

        Returns:
            The constructed ``gr.Blocks`` object (not yet launched).
        """
        app_styles = """
<style>
/* Global Styles */
body, #root {
font-family: Helvetica, Arial, sans-serif;
background-color: #1a1a1a;
color: #fafafa;
}
/* Header Styles */
.app-header {
background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
padding: 24px;
border-radius: 8px;
margin-bottom: 24px;
text-align: center;
}
.app-title {
font-size: 48px;
margin: 0;
color: #fafafa;
}
.app-subtitle {
font-size: 24px;
margin: 8px 0 16px;
color: #fafafa;
}
.app-description {
font-size: 16px;
line-height: 1.6;
opacity: 0.8;
margin-bottom: 24px;
}
/* Button Styles */
.publication-links {
display: flex;
justify-content: center;
flex-wrap: wrap;
gap: 8px;
margin-bottom: 16px;
}
.publication-link {
display: inline-flex;
align-items: center;
padding: 8px 16px;
background-color: #333;
color: #fff !important;
text-decoration: none !important;
border-radius: 20px;
font-size: 14px;
transition: background-color 0.3s;
}
.publication-link:hover {
background-color: #555;
}
.publication-link i {
margin-right: 8px;
}
/* Content Styles */
.content-container {
background-color: #2a2a2a;
border-radius: 8px;
padding: 24px;
margin-bottom: 24px;
}
/* Image Styles */
.image-preview img {
max-width: 256px;
max-height: 256px;
margin: 0 auto;
border-radius: 4px;
display: block;
object-fit: contain;
}
/* Control Styles */
.control-panel {
background-color: #333;
padding: 16px;
border-radius: 8px;
margin-top: 16px;
}
/* Gradio Component Overrides */
.gr-button {
background-color: #4a4a4a;
color: #fff;
border: none;
border-radius: 4px;
padding: 8px 16px;
cursor: pointer;
transition: background-color 0.3s;
}
.gr-button:hover {
background-color: #5a5a5a;
}
.gr-input, .gr-dropdown {
background-color: #3a3a3a;
color: #fff;
border: 1px solid #4a4a4a;
border-radius: 4px;
padding: 8px;
}
.gr-form {
background-color: transparent;
}
.gr-panel {
border: none;
background-color: transparent;
}
</style>
"""
        # NOTE(review): the Bulma URL in the source read "npm/[email protected]/css/...",
        # an e-mail-obfuscation artefact that replaced the "bulma@<version>"
        # segment and made the link unresolvable. It is restored here with a
        # pinned 0.9.3 - confirm the intended version.
        header_html = f"""
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.3/css/bulma.min.css">
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
{app_styles}
<div class="app-header">
<h1 class="app-title">COVID-19 CT Analysis System</h1>
<h2 class="app-subtitle">Classification & Lesion Segmentation</h2>
<p class="app-description">
Upload CT scan images to detect COVID-19 and segment lesions if present.
The system uses ResNet-18 for feature extraction and a U-Net for lesion segmentation.
</p>
</div>
"""
        # Reloads the page once with ?__theme=dark if that parameter is not set.
        js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'dark') {
url.searchParams.set('__theme', 'dark');
window.location.href = url.href;
}
}
"""
        with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
            gr.HTML(header_html)
            with gr.Row(elem_classes="content-container"):
                # Input side: upload, trigger button and example galleries
                with gr.Column():
                    input_image = gr.Image(label="Upload CT Scan Image", type="pil", image_mode="RGB", elem_classes="image-preview")
                    run_button = gr.Button("Analyze Image", elem_classes="gr-button")
                    with gr.Row():
                        with gr.Column(scale=1):
                            gr.Markdown("### COVID Examples")
                            gr.Examples(
                                examples=self.covid_examples,
                                inputs=input_image,
                                label=""
                            )
                        with gr.Column(scale=1):
                            gr.Markdown("### Non-COVID Examples")
                            gr.Examples(
                                examples=self.non_covid_examples,
                                inputs=input_image,
                                label=""
                            )
                # Output side: one tab per view of the result
                with gr.Column():
                    with gr.Tab("Results"):
                        overlay_image = gr.Image(label="Segmentation Overlay", elem_classes="image-preview")
                        result_text = gr.Textbox(label="Analysis Results")
                    with gr.Tab("Segmentation Details"):
                        segmentation_image = gr.Image(label="Lesion Segmentation Map", elem_classes="image-preview")
                        classification_text = gr.Textbox(label="Classification Details")
            # Output order must match the tuple returned by process_image.
            run_button.click(
                fn=process_image,
                inputs=input_image,
                outputs=[overlay_image, result_text, segmentation_image, classification_text],
            )
        return demo
def main():
    """Build the interface and launch it with ``share=True``."""
    demo = GradioInterface().create_interface()
    demo.launch(share=True)


if __name__ == "__main__":
    main()