"""Chest X-ray classifier: fine-tune ResNet-18 once, then serve it with Gradio.

Running this script trains and saves a two-class model when ``model.pth`` does
not exist yet, loads the saved weights, and launches a Gradio web interface.
"""

import os

import gradio as gr
import torch
from datasets import load_dataset
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import models, transforms

# Preprocessing shared by training and inference: resize to 224x224, convert
# to a tensor, then normalize each of the three channels with the mean/std
# constants below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Location of the trained weights; training is skipped when this file exists.
model_path = "model.pth"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class XRayDataset(torch.utils.data.Dataset):
    """Wrap an indexable dataset and yield ``(transformed image, int label)``."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Fetch the row once so the underlying dataset is not indexed twice
        # per sample.
        row = self.dataset[idx]
        # NOTE(review): assumes every row exposes 'image_path' and 'label'
        # fields -- confirm against the actual dataset schema.
        with Image.open(row['image_path']) as raw:
            # convert() returns a fully loaded copy, so the file can be closed.
            img = raw.convert("RGB")
        return transform(img), int(row['label'])


def _build_model(pretrained):
    """Return a ResNet-18 whose final layer is replaced by a 2-class head."""
    # NOTE(review): newer torchvision releases prefer the ``weights=`` argument
    # over ``pretrained=``; kept as-is to avoid assuming a library version.
    net = models.resnet18(pretrained=pretrained)
    net.fc = nn.Linear(net.fc.in_features, 2)
    return net


def _train_and_save(path, epochs=3):
    """Fine-tune a pretrained ResNet-18 and save its state dict to *path*."""
    dataset = load_dataset("hongrui/mimic_chest_xray_v_1", split="train")
    train_loader = DataLoader(XRayDataset(dataset), batch_size=32, shuffle=True)

    net = _build_model(pretrained=True).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    net.train()
    for _ in range(epochs):  # Few epochs for speed
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()

    torch.save(net.state_dict(), path)


# Train only when no saved weights are present.
if not os.path.exists(model_path):
    _train_and_save(model_path)

# Load model for inference
model = _build_model(pretrained=False)
model.load_state_dict(torch.load(model_path, map_location=device))
model = model.to(device)
model.eval()


def predict(img):
    """Classify an uploaded image as ``"Normal"`` or ``"Abnormal"``.

    Args:
        img: A PIL image, or ``None`` when nothing was uploaded.

    Returns:
        ``"Abnormal"`` when the highest-scoring class index is 1, ``"Normal"``
        otherwise, or a short prompt when no image was supplied.
    """
    if img is None:
        # Submitting the form without an image would otherwise crash inside
        # the transform pipeline.
        return "No image provided. Please upload a chest X-ray image."

    # Normalize() is configured for three channels, so force RGB in case the
    # input arrives as grayscale or with an alpha channel.
    batch = transform(img.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(batch)
    _, predicted = torch.max(outputs, 1)
    return "Abnormal" if predicted.item() == 1 else "Normal"


# Gradio app
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Chest X-ray Classification",
    description="Upload a chest X-ray image to classify it as Normal or Abnormal.",
)

interface.launch()