File size: 2,585 Bytes
542e063
1741aed
542e063
 
 
1741aed
542e063
1741aed
 
 
 
 
 
542e063
1741aed
 
542e063
 
1741aed
542e063
 
1741aed
542e063
 
 
1741aed
542e063
 
1741aed
542e063
 
 
 
1741aed
542e063
 
 
 
 
 
 
 
 
 
1741aed
 
542e063
 
 
1741aed
542e063
 
1741aed
 
 
542e063
1741aed
542e063
 
 
 
 
 
 
1741aed
542e063
 
 
1741aed
542e063
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import torch
import gradio as gr
from torchvision import models, transforms
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch import nn, optim
from PIL import Image

# Preprocessing applied to every image, both when training and when predicting:
# resize to 224x224, convert to a tensor, then normalise each of the three
# channels with the mean/std pairs below.
# NOTE(review): these values match the ImageNet statistics torchvision
# documents for its pretrained models — confirm if the backbone changes.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
]
transform = transforms.Compose(_preprocess_steps)

# Train a classifier and cache its weights on the first run; later runs find
# the weights file and skip this whole section.
model_path = "model.pth"

if not os.path.exists(model_path):
    dataset = load_dataset("hongrui/mimic_chest_xray_v_1", split="train")

    class XRayDataset(torch.utils.data.Dataset):
        """Map-style dataset yielding (preprocessed image tensor, int label)."""

        def __init__(self, dataset):
            """Wrap *dataset*, an indexable collection of row mappings."""
            self.dataset = dataset

        def __len__(self):
            """Return the number of rows in the wrapped dataset."""
            return len(self.dataset)

        def __getitem__(self, idx):
            """Load, convert to RGB and preprocess the sample at *idx*."""
            # Fetch the row once; indexing the wrapped dataset separately for
            # each field repeats the row lookup for every sample.
            # NOTE(review): assumes each row has 'image_path' and 'label'
            # fields — confirm against the dataset's actual schema.
            row = self.dataset[idx]
            # convert() returns a fully loaded copy, so the file handle can be
            # released as soon as the conversion is done.
            with Image.open(row['image_path']) as raw:
                img = raw.convert("RGB")
            label = int(row['label'])
            return transform(img), label

    train_dataset = XRayDataset(dataset)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # Start from torchvision's pretrained ResNet-18 and swap the final layer
    # for a two-output head (one output per class).
    model = models.resnet18(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, 2)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    model.train()
    for epoch in range(3):  # Few epochs for speed
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    # Save under a temporary name and rename into place, so an interrupted
    # save never leaves a truncated file that the existence check above would
    # accept on the next start.
    tmp_model_path = model_path + ".tmp"
    torch.save(model.state_dict(), tmp_model_path)
    os.replace(tmp_model_path, model_path)

# Rebuild the network for inference: same architecture as at training time
# (ResNet-18 with a two-output final layer), filled with the saved weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 2)

# map_location lets weights saved on one device load on whichever is available.
saved_weights = torch.load(model_path, map_location=device)
model.load_state_dict(saved_weights)

# eval() returns the module itself, so the chained call leaves `model` on
# `device` and in evaluation mode.
model = model.to(device).eval()

# Inference function
def predict(img):
    """Classify an uploaded image as "Normal" or "Abnormal".

    Args:
        img: PIL image from the Gradio image input, or None when nothing
            was uploaded.

    Returns:
        "Abnormal" when the model's highest-scoring class index is 1,
        "Normal" otherwise. If *img* is None, a short prompt asking the
        user to upload an image.
    """
    # Submitting the form without an image would otherwise crash inside the
    # preprocessing pipeline; answer with a readable message instead.
    if img is None:
        return "Please upload an image."

    # Training converts every image to RGB; do the same here so the
    # three-channel normalisation always sees three channels.
    batch = transform(img.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        predicted = model(batch).argmax(dim=1)
    return "Abnormal" if predicted.item() == 1 else "Normal"

# Gradio app: one image input wired to predict(), with the result shown as text.
_image_input = gr.Image(type="pil")  # hand predict() a PIL image
_app_title = "Chest X-ray Classification"
_app_description = "Upload a chest X-ray image to classify it as Normal or Abnormal."

interface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs="text",
    title=_app_title,
    description=_app_description,
)

# Start the web server as soon as the script runs.
interface.launch()