|
import os |
|
import torch |
|
import gradio as gr |
|
from torchvision import models, transforms |
|
from datasets import load_dataset |
|
from torch.utils.data import DataLoader |
|
from torch import nn, optim |
|
from PIL import Image |
|
|
|
|
|
# Preprocessing shared by training and inference: resize to 224x224, convert
# to a tensor, then normalize each of the three channels with the given
# per-channel mean and standard deviation.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
# Where the fine-tuned weights live. The training branch below runs only when
# this file does not exist yet, and writes it when training finishes.
model_path = "model.pth"

if not os.path.exists(model_path):
    dataset = load_dataset("hongrui/mimic_chest_xray_v_1", split="train")

    class XRayDataset(torch.utils.data.Dataset):
        """Map-style dataset yielding ``(transformed_image, label)`` pairs.

        NOTE(review): assumes every row exposes an ``image_path`` entry that
        ``Image.open`` can read and a ``label`` entry castable to ``int`` --
        confirm against the actual schema of the loaded dataset.
        """

        def __init__(self, dataset):
            """Keep a reference to the underlying row-indexable dataset."""
            self.dataset = dataset

        def __len__(self):
            """Return the number of rows in the wrapped dataset."""
            return len(self.dataset)

        def __getitem__(self, idx):
            """Return the preprocessed image and integer label for row ``idx``."""
            # Fetch the row once instead of indexing the dataset twice.
            row = self.dataset[idx]
            # Image.open leaves the underlying file open; the context manager
            # closes it once convert() has produced an independent copy.
            with Image.open(row["image_path"]) as raw:
                img = raw.convert("RGB")
            label = int(row["label"])
            return transform(img), label

    train_dataset = XRayDataset(dataset)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # Start from a pretrained ResNet-18 and swap the final layer for a
    # two-output head.
    model = models.resnet18(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, 2)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    model.train()
    for epoch in range(3):
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    # Persist only the parameters; the inference code rebuilds the
    # architecture and loads them back.
    torch.save(model.state_dict(), model_path)
|
|
|
|
|
# Rebuild the network for inference and restore the weights saved at
# model_path. The architecture must match the one that was trained:
# ResNet-18 with a two-output final layer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 2)
model.load_state_dict(
    torch.load(model_path, map_location=device)
)
# eval() returns the module itself, so it can be chained after the move.
model = model.to(device).eval()
|
|
|
|
|
def predict(img):
    """Classify a single chest X-ray image.

    Args:
        img: A PIL image supplied by the UI, or ``None`` when the request was
            submitted without an image.

    Returns:
        ``"Abnormal"`` when output index 1 has the highest score, ``"Normal"``
        otherwise. When ``img`` is ``None`` a short prompt asking for an
        image is returned instead of raising.
    """
    if img is None:
        # The image input yields None when nothing was uploaded; answer with
        # a message rather than failing inside the transform.
        return "Please upload a chest X-ray image."

    # Training converts every image to RGB before the transform; do the same
    # here so the three-channel normalization always applies.
    batch = transform(img.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(batch)
    predicted = outputs.argmax(dim=1)
    return "Abnormal" if predicted.item() == 1 else "Normal"
|
|
|
|
|
# Web UI: one PIL image in, the text returned by predict() out.
interface = gr.Interface(
    title="Chest X-ray Classification",
    description="Upload a chest X-ray image to classify it as Normal or Abnormal.",
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
)

# Start serving the app.
interface.launch()
|
|