Mrhuman1 commited on
Commit
542e063
·
verified ·
1 Parent(s): f0e9d36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -67
app.py CHANGED
@@ -1,91 +1,83 @@
1
- import streamlit as st
2
- from datasets import load_dataset
3
- from torchvision import models, transforms
4
  import torch
5
- import torch.nn as nn
 
 
6
  from torch.utils.data import DataLoader
7
- from torchvision.datasets.folder import default_loader
8
  from PIL import Image
9
- import os
10
-
11
- st.set_page_config(page_title="Chest X-ray Classifier", layout="centered")
12
- MODEL_PATH = "resnet_chest_xray.pt"
13
 
14
  # Define transformation
15
  transform = transforms.Compose([
16
  transforms.Resize((224, 224)),
17
  transforms.ToTensor(),
18
- transforms.Normalize([0.485], [0.229])
19
  ])
20
 
21
- # Define model
22
- def get_model():
23
- model = models.resnet34(pretrained=True)
24
- model.fc = nn.Linear(model.fc.in_features, 15)
25
- return model
26
 
27
- # Train the model once if not already saved
28
- @st.cache_resource
29
- def train_model_once():
30
- if os.path.exists(MODEL_PATH):
31
- model = get_model()
32
- model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
33
- model.eval()
34
- return model
35
 
36
- st.write("🔄 Training model (first time only)...")
37
- dataset = load_dataset("BahaaEldin0/NIH-Chest-Xray-14", split="train[:1%]") # small subset
 
38
 
39
- # Preprocess dataset
40
- images, labels = [], []
41
- label_names = dataset.features["labels"].feature.names
42
- for sample in dataset:
43
- img = default_loader(sample["image"].filename)
44
- images.append(transform(img))
45
- label_vector = torch.zeros(15)
46
- for l in sample["labels"]:
47
- label_vector[l] = 1
48
- labels.append(label_vector)
49
 
50
- dataloader = DataLoader(list(zip(images, labels)), batch_size=8, shuffle=True)
 
 
 
51
 
52
- model = get_model()
53
- criterion = nn.BCEWithLogitsLoss()
54
- optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
 
 
 
 
 
 
 
55
 
56
  model.train()
57
- for epoch in range(2): # few epochs
58
- for x, y in dataloader:
 
59
  optimizer.zero_grad()
60
- output = model(x)
61
- loss = criterion(output, y)
62
  loss.backward()
63
  optimizer.step()
64
 
65
- torch.save(model.state_dict(), MODEL_PATH)
66
- model.eval()
67
- return model
68
-
69
- # App interface
70
- st.title("🩻 Chest X-ray Classifier")
71
 
72
- model = train_model_once()
73
-
74
- uploaded_file = st.file_uploader("Upload a Chest X-ray Image", type=["jpg", "jpeg", "png"])
75
-
76
- if uploaded_file:
77
- image = Image.open(uploaded_file).convert("RGB")
78
- st.image(image, caption="Uploaded Image", use_column_width=True)
79
- input_tensor = transform(image).unsqueeze(0)
80
 
 
 
 
81
  with torch.no_grad():
82
- output = model(input_tensor)
83
- pred = torch.sigmoid(output).squeeze()
84
- label_names = ["Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass",
85
- "Nodule", "Pneumonia", "Pneumothorax", "Consolidation", "Edema",
86
- "Emphysema", "Fibrosis", "Pleural_Thickening", "Hernia", "No Finding"]
87
-
88
- st.subheader("Prediction Results:")
89
- for idx, val in enumerate(pred):
90
- if val.item() > 0.5:
91
- st.markdown(f"- **{label_names[idx]}**: {val.item():.2f}")
 
 
 
 
 
1
+ import os
 
 
2
  import torch
3
+ import gradio as gr
4
+ from torchvision import models, transforms
5
+ from datasets import load_dataset
6
  from torch.utils.data import DataLoader
7
+ from torch import nn, optim
8
  from PIL import Image
 
 
 
 
9
 
10
  # Define transformation
11
  transform = transforms.Compose([
12
  transforms.Resize((224, 224)),
13
  transforms.ToTensor(),
14
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
15
  ])
16
 
17
+ # Load and train model if not already saved
18
+ model_path = "model.pth"
 
 
 
19
 
20
+ if not os.path.exists(model_path):
21
+ dataset = load_dataset("hongrui/mimic_chest_xray_v_1", split="train")
 
 
 
 
 
 
22
 
23
+ class XRayDataset(torch.utils.data.Dataset):
24
+ def __init__(self, dataset):
25
+ self.dataset = dataset
26
 
27
+ def __len__(self):
28
+ return len(self.dataset)
 
 
 
 
 
 
 
 
29
 
30
+ def __getitem__(self, idx):
31
+ img = Image.open(self.dataset[idx]['image_path']).convert("RGB")
32
+ label = int(self.dataset[idx]['label'])
33
+ return transform(img), label
34
 
35
+ train_dataset = XRayDataset(dataset)
36
+ train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
37
+
38
+ model = models.resnet18(pretrained=True)
39
+ model.fc = nn.Linear(model.fc.in_features, 2)
40
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
+ model = model.to(device)
42
+
43
+ criterion = nn.CrossEntropyLoss()
44
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
45
 
46
  model.train()
47
+ for epoch in range(3): # Few epochs for speed
48
+ for inputs, labels in train_loader:
49
+ inputs, labels = inputs.to(device), labels.to(device)
50
  optimizer.zero_grad()
51
+ outputs = model(inputs)
52
+ loss = criterion(outputs, labels)
53
  loss.backward()
54
  optimizer.step()
55
 
56
+ torch.save(model.state_dict(), model_path)
 
 
 
 
 
57
 
58
+ # Load model for inference
59
+ model = models.resnet18(pretrained=False)
60
+ model.fc = nn.Linear(model.fc.in_features, 2)
61
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
+ model.load_state_dict(torch.load(model_path, map_location=device))
63
+ model = model.to(device)
64
+ model.eval()
 
65
 
66
+ # Inference function
67
+ def predict(img):
68
+ img = transform(img).unsqueeze(0).to(device)
69
  with torch.no_grad():
70
+ outputs = model(img)
71
+ _, predicted = torch.max(outputs, 1)
72
+ return "Abnormal" if predicted.item() == 1 else "Normal"
73
+
74
+ # Gradio app
75
+ interface = gr.Interface(
76
+ fn=predict,
77
+ inputs=gr.Image(type="pil"),
78
+ outputs="text",
79
+ title="Chest X-ray Classification",
80
+ description="Upload a chest X-ray image to classify it as Normal or Abnormal."
81
+ )
82
+
83
+ interface.launch()