# Spaces status: Sleeping
import logging
import os
import zipfile

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
# Side length, in pixels, of the square image fed to the model; predict()
# resizes every input to IMG_SIZE x IMG_SIZE before inference.
IMG_SIZE = 256
# Extract SavedModel
# The model is shipped as a zip archive; paths are relative to the current
# working directory.
_MODEL_DIR = "lung_cancer_cnn_saved_model"
_MODEL_PB = os.path.join(_MODEL_DIR, "saved_model.pb")

# Test for saved_model.pb rather than for the directory: an extraction that was
# interrupted after the directory was created would otherwise never be retried.
if not os.path.exists(_MODEL_PB):
    print("Extracting SavedModel...")
    os.makedirs(_MODEL_DIR, exist_ok=True)
    with zipfile.ZipFile(_MODEL_DIR + ".zip", "r") as zip_ref:
        zip_ref.extractall(_MODEL_DIR)
    # Fail early with a clear message if the archive did not contain the
    # model file at its top level.
    if not os.path.exists(_MODEL_PB):
        raise FileNotFoundError("saved_model.pb not found")
    print("SavedModel extracted")
# Load the SavedModel once at import time so every request reuses it.
print("Loading model...")
model = tf.saved_model.load("lung_cancer_cnn_saved_model")
# Callable exported under the default serving signature; predict() invokes it
# with keyword tensors and reads its output dict.
infer = model.signatures["serving_default"]
# Indexed by the class id computed in predict(): 0 for scores <= 0.5,
# 1 for scores above 0.5.
class_names = ["Normal", "Lung Cancer"]
print("Model loaded")
def predict(image):
    """Classify one uploaded image as 'Normal' or 'Lung Cancer'.

    The image is reduced to a single grayscale channel, resized to
    IMG_SIZE x IMG_SIZE, scaled by 1/255 and passed to the SavedModel's
    serving signature as a float32 tensor of shape
    (1, IMG_SIZE, IMG_SIZE, 1). The first value of the first output is
    interpreted as the score for class index 1; a score above 0.5 selects
    that class, otherwise class index 0 is reported with confidence
    ``1 - score``.

    Args:
        image: Image delivered by the Gradio input (a PIL image), or any
            object ``np.array`` converts to an HxW, HxWx3 (RGB) or HxWx4
            (RGBA) array. ``None`` means nothing was uploaded.

    Returns:
        A two-line string with the predicted class name and its confidence
        formatted to four decimals, or ``"Error: <message>"`` when the
        input is missing or any step fails. Exceptions are not propagated.
    """
    if image is None:
        # Submitting the form without an upload yields None; report that
        # directly instead of surfacing a low-level OpenCV error.
        return "Error: No image provided"

    try:
        pixels = np.array(image)
        # Collapse colour channels to one. Already-grayscale (2-D) input is
        # passed through untouched; 4-channel input is treated as RGBA.
        if pixels.ndim == 3:
            to_gray = (
                cv2.COLOR_RGBA2GRAY if pixels.shape[2] == 4 else cv2.COLOR_RGB2GRAY
            )
            pixels = cv2.cvtColor(pixels, to_gray)

        scaled = cv2.resize(pixels, (IMG_SIZE, IMG_SIZE)) / 255.0
        batch = scaled.reshape(1, IMG_SIZE, IMG_SIZE, 1).astype(np.float32)
        # Guards against inputs whose pixel values are not in the 0..255 range
        # (for example non-8-bit grayscale images).
        if batch.max() > 1.0 or batch.min() < 0.0:
            raise ValueError("Image not normalized")

        # The signature's input/output names are not hard-coded: use the
        # first (and presumably only) entry of each mapping.
        input_key = next(iter(infer.structured_input_signature[1]))
        outputs = infer(**{input_key: tf.convert_to_tensor(batch)})
        score = next(iter(outputs.values())).numpy()[0][0]
        print("Raw prediction:", score)

        class_id = 1 if score > 0.5 else 0
        confidence = score if class_id == 1 else 1 - score
        return f"Classified as: {class_names[class_id]}\nConfidence: {confidence:.4f}"
    except Exception as exc:
        # UI boundary: keep the full traceback in the server log and show the
        # user only a short message.
        logging.exception("Prediction failed")
        return f"Error: {exc}"
# Create Gradio interface
# One image input (handed to predict() as a PIL image) and one text output
# showing the string predict() returns.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Chest CT Image"),
    outputs=gr.Textbox(label="Classification Result"),
    title="🦋 Lung Cancer Classification CNN",
    description="Classify chest CT images as Normal or Lung Cancer (~84% accuracy).",
    examples=None,
    cache_examples=False,
)

if __name__ == "__main__":
    # Listen on all network interfaces at port 7860, without creating a
    # public share link; show_error surfaces errors in the browser UI.
    iface.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)