# Hugging Face Space page header captured with this file:
# Mrhuman1's picture
# Update app.py
# 542e063 verified
import os
import torch
import gradio as gr
from torchvision import models, transforms
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch import nn, optim
from PIL import Image
# Preprocessing shared by training and inference: resize to 224x224, convert the
# PIL image to a float tensor, then normalize each channel with the standard
# ImageNet mean/std (the statistics the pretrained ResNet-18 backbone expects).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Train and save the classifier only when no checkpoint exists yet; on later
# runs this whole branch is skipped and the saved weights are reused.
model_path = "model.pth"
if not os.path.exists(model_path):
    dataset = load_dataset("hongrui/mimic_chest_xray_v_1", split="train")

    class XRayDataset(torch.utils.data.Dataset):
        """Expose a Hugging Face dataset split as (image tensor, int label) pairs.

        NOTE(review): assumes every row has an ``image_path`` column holding a
        path PIL can open and a ``label`` column convertible to ``int`` (0/1 for
        the 2-way head below) — confirm against the dataset's actual schema.
        """

        def __init__(self, dataset):
            self.dataset = dataset

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            # Look the row up once instead of indexing the dataset twice.
            row = self.dataset[idx]
            # ``with`` closes the file handle opened by PIL; ``convert`` returns
            # a new, fully loaded image that remains usable after the close.
            with Image.open(row["image_path"]) as raw:
                img = raw.convert("RGB")
            label = int(row["label"])
            return transform(img), label

    train_dataset = XRayDataset(dataset)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # ImageNet-pretrained ResNet-18 with its final layer replaced by a 2-class head.
    model = models.resnet18(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, 2)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    model.train()
    for _ in range(3):  # Few epochs for speed
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    # Only the parameters are saved; the inference code rebuilds the architecture.
    torch.save(model.state_dict(), model_path)
# Load model for inference: rebuild the same architecture used in training
# (ResNet-18 with a 2-class head), then restore the weights saved at model_path.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet18(pretrained=False)  # weights come from model_path below
model.fc = nn.Linear(model.fc.in_features, 2)
# A state_dict only contains tensors, so restrict unpickling to weights; this
# avoids executing arbitrary pickled objects if the checkpoint file is replaced.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()  # inference mode (affects layers such as BatchNorm)
# Inference function
def predict(img):
    """Classify one chest X-ray image as "Normal" or "Abnormal".

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        "Abnormal" if the model's highest-scoring class index is 1, else "Normal".

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message
            instead of a stack trace from the preprocessing pipeline.
    """
    if img is None:
        raise gr.Error("Please upload a chest X-ray image.")
    # Mirror the training pipeline, which converts every image to RGB: the
    # 3-channel Normalize in ``transform`` and the model both need 3 channels,
    # and X-ray uploads are often grayscale or RGBA.
    batch = transform(img.convert("RGB")).unsqueeze(0).to(device)  # add batch dim
    with torch.no_grad():
        logits = model(batch)
    predicted = logits.argmax(dim=1).item()
    # NOTE(review): assumes training label 1 means "abnormal" — confirm against
    # the dataset's label definition.
    return "Abnormal" if predicted == 1 else "Normal"
# Gradio app: one uploaded image in, the text label returned by `predict` out.
_upload = gr.Image(type="pil")
_title = "Chest X-ray Classification"
_description = "Upload a chest X-ray image to classify it as Normal or Abnormal."
interface = gr.Interface(
    fn=predict,
    inputs=_upload,
    outputs="text",
    title=_title,
    description=_description,
)
interface.launch()