File size: 2,170 Bytes
eb66ea6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61

import gradio as gr
import tensorflow as tf
import numpy as np
import cv2
import zipfile
import os

# Side length, in pixels, of the square grayscale input built in predict().
IMG_SIZE = 256

# Locations of the SavedModel directory, its archive, and the graph file
# whose presence marks a usable extraction.
MODEL_DIR = "lung_cancer_cnn_saved_model"
MODEL_ZIP = MODEL_DIR + ".zip"
MODEL_PB = os.path.join(MODEL_DIR, "saved_model.pb")

# Extract SavedModel.
# The guard checks for saved_model.pb rather than the directory: the directory
# is created before the archive is opened, so a missing/corrupt zip or an
# interrupted extraction would otherwise leave a directory behind that makes
# every later start skip extraction and fail at load time.
if not os.path.exists(MODEL_PB):
    print("Extracting SavedModel...")
    os.makedirs(MODEL_DIR, exist_ok=True)
    with zipfile.ZipFile(MODEL_ZIP, 'r') as zip_ref:
        zip_ref.extractall(MODEL_DIR)
    # The archive must place saved_model.pb at its top level.
    if not os.path.exists(MODEL_PB):
        raise FileNotFoundError("saved_model.pb not found")
    print("SavedModel extracted")

print("Loading model...")
model = tf.saved_model.load(MODEL_DIR)
# Inference goes through the 'serving_default' signature of the loaded model.
infer = model.signatures['serving_default']
# Index 0 is reported for scores <= 0.5 and index 1 for scores > 0.5
# (see predict()).
class_names = ['Normal', 'Lung Cancer']
print("Model loaded")

def predict(image):
    """Classify one uploaded image and return a human-readable result string.

    The image is converted to grayscale, resized to IMG_SIZE x IMG_SIZE,
    scaled to [0, 1], and passed to the module-level ``infer`` signature as
    a float32 array of shape (1, IMG_SIZE, IMG_SIZE, 1).

    Args:
        image: The uploaded image (anything ``np.array`` turns into an RGB
            array accepted by ``cv2.cvtColor``), or ``None`` when nothing
            was uploaded.

    Returns:
        ``"Classified as: <label>\\nConfidence: <score>"`` on success, or a
        string starting with ``"Error: "`` when the input is missing or any
        step fails. Exceptions are not propagated to the caller.
    """
    # Submitting the form without an upload yields None; report that clearly
    # instead of surfacing an opaque OpenCV conversion error.
    if image is None:
        return "Error: No image provided"

    try:
        img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE)) / 255.0
        # Add the batch and channel axes.
        img = img.reshape(1, IMG_SIZE, IMG_SIZE, 1).astype(np.float32)
        # Sanity guard: values outside [0, 1] mean the input was not 8-bit.
        if img.max() > 1.0 or img.min() < 0.0:
            raise ValueError("Image not normalized")

        # The signature's input/output names are looked up rather than
        # hard-coded; the first (presumably only) entry of each is used.
        input_key = list(infer.structured_input_signature[1].keys())[0]
        inputs = {input_key: tf.convert_to_tensor(img)}
        outputs = infer(**inputs)
        output_key = list(outputs.keys())[0]
        pred = outputs[output_key].numpy()
        print("Raw prediction:", pred[0][0])

        # pred[0][0] is treated as the score of class 1 ('Lung Cancer'):
        # above 0.5 selects class 1, otherwise class 0, and the confidence
        # is the score of whichever class was selected.
        class_id = 1 if pred[0][0] > 0.5 else 0
        confidence = pred[0][0] if class_id == 1 else 1 - pred[0][0]

        return f"Classified as: {class_names[class_id]}\nConfidence: {confidence:.4f}"
    except Exception as e:
        # UI boundary: turn any failure into a message shown in the textbox.
        return f"Error: {str(e)}"

# Assemble the Gradio UI: a single image upload feeding predict(), with the
# returned text shown in a single textbox.
_image_input = gr.Image(type="pil", label="Upload Chest CT Image")
_result_box = gr.Textbox(label="Classification Result")

iface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_result_box,
    title="🦋 Lung Cancer Classification CNN",
    description="Classify chest CT images as Normal or Lung Cancer (~84% accuracy).",
    examples=None,
    cache_examples=False,
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    _launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    iface.launch(**_launch_options)