File size: 2,585 Bytes
542e063 1741aed 542e063 1741aed 542e063 1741aed 542e063 1741aed 542e063 1741aed 542e063 1741aed 542e063 1741aed 542e063 1741aed 542e063 1741aed 542e063 1741aed 542e063 1741aed 542e063 1741aed 542e063 1741aed 542e063 1741aed 542e063 1741aed 542e063 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import os
import torch
import gradio as gr
from torchvision import models, transforms
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch import nn, optim
from PIL import Image
# Preprocessing shared by training and inference: resize to 224x224, convert
# to a tensor, then normalize each of the three channels.
# NOTE(review): the mean/std values look like the standard ImageNet channel
# statistics used with torchvision's pretrained models — confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Train and save the classifier only when no checkpoint file exists yet;
# otherwise this whole section is skipped and the saved weights are reused.
model_path = "model.pth"
if not os.path.exists(model_path):
    # NOTE(review): the code below assumes each row has 'image_path' and
    # 'label' fields — confirm against the dataset's actual schema.
    dataset = load_dataset("hongrui/mimic_chest_xray_v_1", split="train")

    class XRayDataset(torch.utils.data.Dataset):
        """Wrap the loaded dataset so it yields (image tensor, int label) pairs."""

        def __init__(self, dataset):
            self.dataset = dataset

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            """Return the transformed RGB image and integer label at *idx*."""
            # Index the underlying dataset once per sample instead of once
            # per field, so the row is not fetched twice.
            row = self.dataset[idx]
            # The context manager releases the file handle as soon as the
            # RGB copy has been made.
            with Image.open(row['image_path']) as raw:
                img = raw.convert("RGB")
            return transform(img), int(row['label'])

    train_dataset = XRayDataset(dataset)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # NOTE(review): newer torchvision releases deprecate `pretrained=` in
    # favour of `weights=` — confirm the installed version before switching.
    model = models.resnet18(pretrained=True)
    # Replace the final layer with a 2-way classification head.
    model.fc = nn.Linear(model.fc.in_features, 2)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    model.train()
    for epoch in range(3):  # Few epochs for speed
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    # Persist only the weights; the inference section rebuilds the
    # architecture and loads them back.
    torch.save(model.state_dict(), model_path)
# Rebuild the network for inference: same ResNet-18 with a 2-way head,
# populated from the checkpoint at `model_path` and switched to eval mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 2)
# map_location lets a checkpoint saved on one device load on the current one.
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device).eval()
# Inference function
def predict(img):
    """Classify a chest X-ray image as "Normal" or "Abnormal".

    Args:
        img: PIL image from the Gradio image input, or None when the form
            is submitted without an image.

    Returns:
        "Abnormal" if the highest-scoring class index is 1, otherwise
        "Normal". If *img* is None, a short prompt asking for an upload.
    """
    # Gradio can call this with None when nothing was uploaded; answer with
    # a message instead of failing inside the transform.
    if img is None:
        return "Please upload a chest X-ray image."
    # The training path converts every image to RGB; do the same here so
    # grayscale or RGBA inputs still match the 3-channel normalization.
    batch = transform(img.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(batch)
    predicted = outputs.argmax(dim=1)
    return "Abnormal" if predicted.item() == 1 else "Normal"
# Web UI: a single image upload wired to `predict`, with the result shown
# as plain text. The server is started as soon as this script runs.
interface = gr.Interface(
    title="Chest X-ray Classification",
    description="Upload a chest X-ray image to classify it as Normal or Abnormal.",
    fn=predict,
    inputs=gr.Image(type="pil"),  # hand the callback a PIL image
    outputs="text",
)
interface.launch()
|