|
"""Gradio demo app for COVID-19 CT analysis: a PyTorch classifier over ResNet-18
features plus a Keras model for lesion segmentation."""

import gradio as gr
|
import numpy as np |
|
import tensorflow as tf |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torchvision import models, transforms |
|
import cv2 |
|
from tensorflow.keras.models import load_model |
|
from PIL import Image |
|
import os |
|
import pickle |
|
from tensorflow.keras import backend as K |
|
|
|
|
|
DISPLAY_DIMS = (256, 256)  # size of all displayed images and overlays
CLASS_DIMS = (224, 224)    # ResNet-18 input size
SEG_DIMS = (128, 128)      # Keras segmentation model input size
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
def dice_coefficient(y_true, y_pred, smooth=1):
    """Soft Dice coefficient, 2*|A∩B| / (|A| + |B|), smoothed to avoid 0/0."""
    y_true_f = K.flatten(tf.cast(y_true, tf.float32))
|
y_pred_f = K.flatten(tf.cast(y_pred, tf.float32)) |
|
intersection = K.sum(y_true_f * y_pred_f) |
|
return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) |
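
# Illustrative sanity check (not part of the app flow): a mask compared with
# itself scores ~1.0; disjoint masks score near 0 (the smooth term avoids 0/0).
#
#   _mask = tf.ones((1, 128, 128, 1))
#   float(dice_coefficient(_mask, _mask))                 # ~1.0
#   float(dice_coefficient(_mask, tf.zeros_like(_mask)))  # ~0.0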
|
|
|
|
|
class ClassificationModel(nn.Module):
    """Small MLP head over selected ResNet-18 features; outputs 2 class logits."""

    def __init__(self, input_dim):
|
super(ClassificationModel, self).__init__() |
|
self.fc1 = nn.Linear(input_dim, 128) |
|
self.fc2 = nn.Linear(128, 64) |
|
self.fc3 = nn.Linear(64, 16) |
|
self.fc4 = nn.Linear(16, 2) |
|
self.dropout = nn.Dropout(0.3) |
|
|
|
def forward(self, x): |
|
x = F.relu(self.fc1(x)) |
|
x = self.dropout(x) |
|
x = F.relu(self.fc2(x)) |
|
x = self.dropout(x) |
|
x = F.relu(self.fc3(x)) |
|
x = self.fc4(x) |
|
return x |
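
# Illustrative shape check (256 is an arbitrary feature width, not the app's
# actual selector output):
#
#   _m = ClassificationModel(input_dim=256)
#   _m(torch.randn(8, 256)).shape  # -> torch.Size([8, 2]), one logit pair per sample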
|
|
|
|
|
# Load all model artifacts at import time; on any failure fall back to demo mode.
try:
|
|
|
    # torchvision >= 0.13 deprecates pretrained=True in favor of the weights API.
    resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    # Drop the final fc layer; the truncated network emits 512-d pooled features.
    resnet = nn.Sequential(*list(resnet.children())[:-1])
    resnet.to(device)
    resnet.eval()
|
|
|
|
|
    # The pickled selector is a scikit-learn-style object exposing
    # get_support()/transform(); it picks which ResNet features feed the classifier.
    with open("feature_selector.pkl", "rb") as f:
        selector = pickle.load(f)

    # Classifier input width = number of features the selector kept.
    input_dim = int(selector.get_support().sum())
    classification_model = ClassificationModel(input_dim).to(device)
    classification_model.load_state_dict(torch.load("trained_model.pth", map_location=device))
    classification_model.eval()
|
|
|
|
|
transform = transforms.Compose([ |
|
        transforms.Resize(CLASS_DIMS),
|
transforms.ToTensor(), |
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), |
|
]) |
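    # The normalization mean/std above are the standard ImageNet statistics
    # expected by torchvision's pretrained ResNet weights.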
|
|
|
|
|
segmentation_model = None |
|
if os.path.exists("segmentation_model.h5"): |
|
segmentation_model = load_model("segmentation_model.h5", |
|
custom_objects={'dice_coefficient': dice_coefficient}, |
|
compile=False) |
|
print("Loaded segmentation_model.h5") |
|
elif os.path.exists("best_model.keras"): |
|
segmentation_model = load_model("best_model.keras", |
|
custom_objects={'dice_coefficient': dice_coefficient}, |
|
compile=False) |
|
print("Loaded best_model.keras") |
|
|
|
models_loaded = True |
|
print("Models loaded successfully!") |
|
except Exception as e: |
|
print(f"Error loading models: {e}") |
|
print("The app will run in demo mode with simulated predictions.") |
|
models_loaded = False |
|
resnet = None |
|
selector = None |
|
classification_model = None |
|
segmentation_model = None |
|
transform = None |
|
|
|
|
|
def preprocess_for_classification(image):
    """Convert any input to an RGB PIL image and apply the ResNet transform."""
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.array(image))
    image = image.convert("RGB")
    return transform(image).unsqueeze(0).to(device)
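
# Illustrative (hypothetical path): preprocess_for_classification(Image.open("ct.png"))
# yields a (1, 3, 224, 224) float tensor on `device`, ready for the truncated ResNet.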
|
|
|
|
|
def preprocess_for_segmentation(image):
    """Normalize an input image into the batch format the Keras model expects."""
    if isinstance(image, Image.Image):
        image = np.array(image)
|
|
|
|
|
if len(image.shape) == 2: |
|
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) |
|
elif image.shape[2] == 4: |
|
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB) |
|
|
|
|
|
    image = cv2.resize(image, SEG_DIMS)
    image = image / 255.0  # scale to [0, 1]
    image = np.expand_dims(image, axis=0)  # add the batch dimension
|
|
|
return image |
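
# Illustrative (hypothetical path): preprocess_for_segmentation(Image.open("ct.png"))
# returns a float array of shape (1, 128, 128, 3) with values scaled to [0, 1].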
|
|
|
|
|
def classify_image(image):
    """Classify a CT image as COVID/Non-COVID; simulated output in demo mode."""
    if image is None:
        return "No image provided", None, 0
|
|
|
try: |
|
if models_loaded and resnet is not None and classification_model is not None: |
|
|
|
img_tensor = preprocess_for_classification(image) |
|
|
|
with torch.no_grad(): |
|
features = resnet(img_tensor).view(-1).cpu().numpy() |
|
|
|
|
|
features_selected = selector.transform(features.reshape(1, -1)) |
|
input_tensor = torch.tensor(features_selected, dtype=torch.float32).to(device) |
|
|
|
|
|
            with torch.no_grad():
                output = classification_model(input_tensor)
            print("Classification output:", output)
            predicted_class = torch.argmax(output, dim=1).item()
            print("Classification predicted class:", predicted_class)
            probabilities = F.softmax(output, dim=1)
            print("Classification probabilities:", probabilities)
            confidence = probabilities[0][predicted_class].item()
|
|
|
|
|
            # Label convention from training: class 0 = COVID, class 1 = Non-COVID.
            status = "COVID" if predicted_class == 0 else "Non-COVID"
|
|
|
return f"Predicted: {status} (Class: {predicted_class}, Confidence: {confidence:.2f})", image, predicted_class |
|
else: |
|
|
|
import random |
|
predicted_class = random.randint(0, 1) |
|
confidence = random.uniform(0.7, 0.99) |
|
status = "COVID" if predicted_class == 0 else "Non-COVID" |
|
return f"Predicted: {status} (Class: {predicted_class}, Confidence: {confidence:.2f}) [DEMO]", image, predicted_class |
|
except Exception as e: |
|
return f"Error during classification: {str(e)}", image, 0 |
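
# Illustrative: classify_image returns (summary string, input image, class index);
# on error the class index defaults to 0.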
|
|
|
|
|
def segment_image(image):
    """Run lesion segmentation; returns (mask image, overlay) or (None, None)."""
    if image is None:
        return None, None
|
|
|
try: |
|
if models_loaded and segmentation_model is not None: |
|
|
|
input_image = preprocess_for_segmentation(image) |
|
|
|
|
|
pred_mask = segmentation_model.predict(input_image) |
|
binary_mask = (pred_mask > 0.5).astype(np.uint8) |
|
|
|
|
|
            # np.array() handles both PIL images and plain arrays.
            display_image = np.array(image)
|
|
|
|
|
display_image = cv2.resize(display_image, DISPLAY_DIMS) |
|
|
|
|
|
            # Nearest-neighbor interpolation keeps the upscaled mask strictly binary.
            display_mask = cv2.resize(binary_mask[0].squeeze(), DISPLAY_DIMS,
                                      interpolation=cv2.INTER_NEAREST)
|
|
|
|
|
overlay = display_image.copy() |
|
if len(overlay.shape) == 2: |
|
overlay = cv2.cvtColor(overlay, cv2.COLOR_GRAY2RGB) |
|
elif overlay.shape[2] == 4: |
|
overlay = cv2.cvtColor(overlay, cv2.COLOR_RGBA2RGB) |
|
|
|
|
|
            # Tint lesion pixels red: boost the red channel, damp green and blue.
            overlay[:, :, 0] = np.maximum(overlay[:, :, 0], display_mask * 255)
            overlay[:, :, 1] = np.where(display_mask > 0, overlay[:, :, 1] * 0.5, overlay[:, :, 1])
            overlay[:, :, 2] = np.where(display_mask > 0, overlay[:, :, 2] * 0.5, overlay[:, :, 2])
|
|
|
|
|
            # Fraction of pixels flagged as lesion; computed for reference but not
            # currently surfaced in the UI.
            lesion_percentage = np.sum(binary_mask) / binary_mask.size * 100
|
|
|
|
|
|
|
            # Colorize the mask for display; applyColorMap outputs BGR, so convert to RGB.
            enhanced_mask = cv2.normalize(display_mask, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
            enhanced_mask = cv2.applyColorMap(enhanced_mask, cv2.COLORMAP_JET)
            enhanced_mask = cv2.cvtColor(enhanced_mask, cv2.COLOR_BGR2RGB)
|
|
|
|
|
return enhanced_mask, overlay |
|
else: |
|
|
|
return simulate_segmentation(image) |
|
    except Exception as e:
        # Log and return an empty pair so the caller's 2-value unpack succeeds.
        print(f"Error during segmentation: {e}")
        return None, None
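
# Illustrative: on success segment_image returns (colorized 256x256 mask,
# 256x256 RGB overlay); on failure it logs the error and returns (None, None).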
|
|
|
|
|
def simulate_segmentation(image):
    """Demo-mode fallback: draws random circular 'lesions' instead of real output."""
    import random
|
|
|
    # np.array() handles both PIL images and plain arrays.
    display_image = np.array(image)
|
|
|
if len(display_image.shape) == 2: |
|
display_image = cv2.cvtColor(display_image, cv2.COLOR_GRAY2RGB) |
|
elif display_image.shape[2] == 4: |
|
display_image = cv2.cvtColor(display_image, cv2.COLOR_RGBA2RGB) |
|
|
|
display_image = cv2.resize(display_image, DISPLAY_DIMS) |
|
|
|
|
|
mask = np.zeros(DISPLAY_DIMS, dtype=np.uint8) |
|
|
|
|
|
num_blobs = random.randint(1, 3) |
|
for i in range(num_blobs): |
|
center_x = random.randint(50, DISPLAY_DIMS[0]-50) |
|
center_y = random.randint(50, DISPLAY_DIMS[1]-50) |
|
radius = random.randint(10, 30) |
|
cv2.circle(mask, (center_x, center_y), radius, 1, -1) |
|
|
|
|
|
overlay = display_image.copy() |
|
|
|
|
|
overlay[:, :, 0] = np.maximum(overlay[:, :, 0], mask * 255) |
|
overlay[:, :, 1] = np.where(mask > 0, overlay[:, :, 1] * 0.5, overlay[:, :, 1]) |
|
overlay[:, :, 2] = np.where(mask > 0, overlay[:, :, 2] * 0.5, overlay[:, :, 2]) |
|
|
|
|
|
    # Colorize the mask for display; applyColorMap outputs BGR, so convert to RGB.
    enhanced_mask = cv2.normalize(mask, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    enhanced_mask = cv2.applyColorMap(enhanced_mask, cv2.COLORMAP_JET)
    enhanced_mask = cv2.cvtColor(enhanced_mask, cv2.COLOR_BGR2RGB)
|
|
|
    # Fraction of pixels flagged as lesion; computed for reference but unused.
    lesion_percentage = np.sum(mask) / mask.size * 100
|
|
|
|
|
return enhanced_mask, overlay |
|
|
|
|
|
def process_image(image):
    """Pipeline entry point wired to the Analyze button: classify, then segment."""
    if image is None:
        return None, "No image provided", None, "No image provided"
|
|
|
|
|
    classification_result, processed_image, predicted_class = classify_image(image)
    segmentation_map, overlay_image = segment_image(image)

    # Return order matches the `outputs` list wired to run_button.click:
    # (overlay image, result text, segmentation map, classification text).
    return overlay_image, classification_result, segmentation_map, classification_result
|
|
|
|
|
def _load_examples(file_prefix, label):
    """Collect example image paths for `file_prefix`; fall back to placeholders."""
    examples = []
    try:
        for i in range(1, 6):
            path = f"./examples/{file_prefix} ({i}).png"
            if os.path.exists(path):
                examples.append([path])

        # No files found: generate simple labeled placeholder images instead.
        if len(examples) == 0:
            for i in range(1, 6):
                img = np.ones((256, 256, 3), dtype=np.uint8) * 200
                cv2.putText(img, f"{label} Example {i}", (30, 128),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (100, 100, 100), 2)
                examples.append([img])
    except Exception as e:
        print(f"Could not load {label} examples: {e}")

    return examples


def load_covid_examples():
    return _load_examples("Covid", "COVID")


def load_non_covid_examples():
    return _load_examples("Non-Covid", "Non-COVID")
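
# Illustrative: each examples entry is a one-element list because gr.Examples
# supplies one value per input component (here, only the image).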
|
|
|
class GradioInterface:
    """Builds the Gradio Blocks UI and wires it to process_image."""

    def __init__(self):
|
self.covid_examples = load_covid_examples() |
|
self.non_covid_examples = load_non_covid_examples() |
|
|
|
def create_interface(self): |
|
app_styles = """ |
|
<style> |
|
/* Global Styles */ |
|
body, #root { |
|
font-family: Helvetica, Arial, sans-serif; |
|
background-color: #1a1a1a; |
|
color: #fafafa; |
|
} |
|
/* Header Styles */ |
|
.app-header { |
|
background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%); |
|
padding: 24px; |
|
border-radius: 8px; |
|
margin-bottom: 24px; |
|
text-align: center; |
|
} |
|
.app-title { |
|
font-size: 48px; |
|
margin: 0; |
|
color: #fafafa; |
|
} |
|
.app-subtitle { |
|
font-size: 24px; |
|
margin: 8px 0 16px; |
|
color: #fafafa; |
|
} |
|
.app-description { |
|
font-size: 16px; |
|
line-height: 1.6; |
|
opacity: 0.8; |
|
margin-bottom: 24px; |
|
} |
|
/* Button Styles */ |
|
.publication-links { |
|
display: flex; |
|
justify-content: center; |
|
flex-wrap: wrap; |
|
gap: 8px; |
|
margin-bottom: 16px; |
|
} |
|
.publication-link { |
|
display: inline-flex; |
|
align-items: center; |
|
padding: 8px 16px; |
|
background-color: #333; |
|
color: #fff !important; |
|
text-decoration: none !important; |
|
border-radius: 20px; |
|
font-size: 14px; |
|
transition: background-color 0.3s; |
|
} |
|
.publication-link:hover { |
|
background-color: #555; |
|
} |
|
.publication-link i { |
|
margin-right: 8px; |
|
} |
|
/* Content Styles */ |
|
.content-container { |
|
background-color: #2a2a2a; |
|
border-radius: 8px; |
|
padding: 24px; |
|
margin-bottom: 24px; |
|
} |
|
/* Image Styles */ |
|
.image-preview img { |
|
max-width: 256px; |
|
max-height: 256px; |
|
margin: 0 auto; |
|
border-radius: 4px; |
|
display: block; |
|
object-fit: contain; |
|
} |
|
/* Control Styles */ |
|
.control-panel { |
|
background-color: #333; |
|
padding: 16px; |
|
border-radius: 8px; |
|
margin-top: 16px; |
|
} |
|
/* Gradio Component Overrides */ |
|
.gr-button { |
|
background-color: #4a4a4a; |
|
color: #fff; |
|
border: none; |
|
border-radius: 4px; |
|
padding: 8px 16px; |
|
cursor: pointer; |
|
transition: background-color 0.3s; |
|
} |
|
.gr-button:hover { |
|
background-color: #5a5a5a; |
|
} |
|
.gr-input, .gr-dropdown { |
|
background-color: #3a3a3a; |
|
color: #fff; |
|
border: 1px solid #4a4a4a; |
|
border-radius: 4px; |
|
padding: 8px; |
|
} |
|
.gr-form { |
|
background-color: transparent; |
|
} |
|
.gr-panel { |
|
border: none; |
|
background-color: transparent; |
|
} |
|
</style> |
|
""" |
|
|
|
header_html = f""" |
|
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css"> |
|
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css"> |
|
{app_styles} |
|
<div class="app-header"> |
|
<h1 class="app-title">COVID-19 CT Analysis System</h1> |
|
<h2 class="app-subtitle">Classification & Lesion Segmentation</h2> |
|
<p class="app-description"> |
|
Upload CT scan images to detect COVID-19 and segment lesions if present. |
|
The system uses ResNet-18 for feature extraction and a U-Net for lesion segmentation. |
|
</p> |
|
</div> |
|
""" |
|
|
|
js_func = """ |
|
function refresh() { |
|
const url = new URL(window.location); |
|
if (url.searchParams.get('__theme') !== 'dark') { |
|
url.searchParams.set('__theme', 'dark'); |
|
window.location.href = url.href; |
|
} |
|
} |
|
""" |
|
|
|
with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo: |
|
gr.HTML(header_html) |
|
|
|
with gr.Row(elem_classes="content-container"): |
|
with gr.Column(): |
|
input_image = gr.Image(label="Upload CT Scan Image", type="pil", image_mode="RGB", elem_classes="image-preview") |
|
run_button = gr.Button("Analyze Image", elem_classes="gr-button") |
|
|
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
covid_examples_title = gr.Markdown("### COVID Examples") |
|
covid_examples = gr.Examples( |
|
examples=self.covid_examples, |
|
inputs=input_image, |
|
label="" |
|
) |
|
|
|
with gr.Column(scale=1): |
|
non_covid_examples_title = gr.Markdown("### Non-COVID Examples") |
|
non_covid_examples = gr.Examples( |
|
examples=self.non_covid_examples, |
|
inputs=input_image, |
|
label="" |
|
) |
|
|
|
with gr.Column(): |
|
with gr.Tab("Results"): |
|
overlay_image = gr.Image(label="Segmentation Overlay", elem_classes="image-preview") |
|
result_text = gr.Textbox(label="Analysis Results") |
|
|
|
with gr.Tab("Segmentation Details"): |
|
segmentation_image = gr.Image(label="Lesion Segmentation Map", elem_classes="image-preview") |
|
classification_text = gr.Textbox(label="Classification Details") |
|
|
|
run_button.click( |
|
fn=process_image, |
|
inputs=input_image, |
|
outputs=[overlay_image, result_text, segmentation_image, classification_text], |
|
) |
|
|
|
return demo |
|
|
|
def main(): |
|
interface = GradioInterface() |
|
demo = interface.create_interface() |
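    # share=True below also requests a temporary public gradio.live URL;
    # pass share=False to keep the app local-only.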
|
demo.launch(share=True) |
|
|
|
if __name__ == "__main__": |
|
main() |