Spaces:
Sleeping
Sleeping
File size: 2,170 Bytes
eb66ea6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import os
import traceback
import zipfile

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
# Side length in pixels; predict() resizes every upload to IMG_SIZE x IMG_SIZE.
IMG_SIZE = 256

MODEL_DIR = "lung_cancer_cnn_saved_model"
MODEL_ZIP = "lung_cancer_cnn_saved_model.zip"


def _ensure_saved_model(model_dir=MODEL_DIR, zip_path=MODEL_ZIP):
    """Make sure ``model_dir`` holds an extracted SavedModel, unzipping if needed.

    Extraction is keyed on ``saved_model.pb`` being present rather than on the
    directory merely existing. The directory is created before the archive is
    unpacked, so an interrupted earlier run can leave it empty or partial.
    Such a directory is re-extracted here instead of being silently reused.

    Args:
        model_dir: Directory the SavedModel lives in / is extracted into.
        zip_path: Archive containing the SavedModel files at its top level.

    Raises:
        FileNotFoundError: if ``zip_path`` does not exist, or if the archive
            does not place ``saved_model.pb`` directly inside ``model_dir``.
    """
    pb_path = os.path.join(model_dir, "saved_model.pb")
    if os.path.exists(pb_path):
        return  # already extracted on a previous start

    print("Extracting SavedModel...")
    os.makedirs(model_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(model_dir)
    if not os.path.exists(pb_path):
        raise FileNotFoundError("saved_model.pb not found")
    print("SavedModel extracted")


_ensure_saved_model()
# Label lookup used by predict(): index 0 -> 'Normal', index 1 -> 'Lung Cancer'.
class_names = ['Normal', 'Lung Cancer']

print("Loading model...")
model = tf.saved_model.load("lung_cancer_cnn_saved_model")
# predict() calls this serving signature directly instead of the loaded object.
infer = model.signatures['serving_default']
print("Model loaded")
def _to_grayscale(pixels):
    """Collapse an image array to a single-channel (H, W) intensity array.

    2-D arrays are taken as already grayscale, H x W x 1 is squeezed,
    H x W x 3 is treated as RGB, and H x W x 4 as RGBA (alpha is dropped).
    Any other shape raises ValueError.
    """
    if pixels.ndim == 2:
        return pixels
    if pixels.ndim == 3:
        channels = pixels.shape[2]
        if channels == 1:
            return pixels.reshape(pixels.shape[:2])
        if channels == 3:
            return cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)
        if channels == 4:
            return cv2.cvtColor(pixels, cv2.COLOR_RGBA2GRAY)
    raise ValueError(f"Unsupported image shape: {pixels.shape}")


def predict(image):
    """Classify one uploaded image as 'Normal' or 'Lung Cancer'.

    Args:
        image: The upload as a PIL image (what ``gr.Image(type="pil")``
            supplies), or anything ``np.array`` turns into an H x W,
            H x W x 1, H x W x 3 (RGB) or H x W x 4 (RGBA) array.
            ``None`` means nothing was uploaded.

    Returns:
        ``"Classified as: <label>\\nConfidence: <p>"`` on success, otherwise a
        string starting with ``"Error: "``. Failures are reported as text
        rather than raised, so the output textbox always shows a result.
    """
    if image is None:
        # e.g. the form was submitted without an upload
        return "Error: no image provided"
    try:
        gray = _to_grayscale(np.array(image))
        img = cv2.resize(gray, (IMG_SIZE, IMG_SIZE)) / 255.0
        # Batch of one, single channel, float32: (1, IMG_SIZE, IMG_SIZE, 1).
        img = img.reshape(1, IMG_SIZE, IMG_SIZE, 1).astype(np.float32)
        # 8-bit input always lands in [0, 1]. This only trips for deeper bit
        # depths (e.g. 16-bit data) that dividing by 255 does not normalise.
        if img.max() > 1.0 or img.min() < 0.0:
            raise ValueError("Image not normalized")

        # The input keyword name is read from the signature rather than
        # hard-coded. The first (presumably only) input and output are used.
        input_key = next(iter(infer.structured_input_signature[1]))
        outputs = infer(**{input_key: tf.convert_to_tensor(img)})
        pred = outputs[next(iter(outputs))].numpy()

        prob = pred[0][0]
        print("Raw prediction:", prob)
        # prob is read as the score for class 1 ('Lung Cancer').
        # NOTE(review): assumes a single sigmoid output unit — confirm
        # against the training code.
        class_id = 1 if prob > 0.5 else 0
        confidence = prob if class_id == 1 else 1 - prob
        return f"Classified as: {class_names[class_id]}\nConfidence: {confidence:.4f}"
    except Exception as e:
        # UI boundary: show the failure in the textbox, but keep the full
        # traceback in the server log instead of discarding it.
        traceback.print_exc()
        return f"Error: {e}"
# Gradio UI: one image upload wired to predict(), whose text result is shown
# in a textbox.
ct_upload = gr.Image(type="pil", label="Upload Chest CT Image")
result_box = gr.Textbox(label="Classification Result")

iface = gr.Interface(
    fn=predict,
    inputs=ct_upload,
    outputs=result_box,
    title="🦋 Lung Cancer Classification CNN",
    description="Classify chest CT images as Normal or Lung Cancer (~84% accuracy).",
    examples=None,
    cache_examples=False,
)

if __name__ == "__main__":
    # Bind all interfaces on port 7860, no public share link, show_error on.
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )
|