Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision.transforms as transforms | |
from PIL import Image | |
import gradio as gr | |
# Model architecture: must match the training-time definition exactly so the saved state_dict loads.
class ImageClassificationBase(nn.Module):
    """nn.Module base class that adds per-batch and per-epoch validation helpers.

    Subclasses supply ``forward``; these helpers call ``self(...)`` and the
    module-level ``accuracy`` function.
    """

    def validation_step(self, batch):
        """Compute loss and accuracy for one validation batch.

        Args:
            batch: an ``(images, labels)`` pair; labels are class indices.

        Returns:
            dict with ``'val_loss'`` (detached cross-entropy tensor) and
            ``'val_acc'`` (tensor returned by ``accuracy``).
        """
        images, labels = batch
        logits = self(images)
        # Dict values are evaluated left to right, so loss is computed before accuracy.
        return {
            'val_loss': F.cross_entropy(logits, labels).detach(),
            'val_acc': accuracy(logits, labels),
        }

    def validation_epoch_end(self, outputs):
        """Average the per-batch metrics produced by ``validation_step``.

        Each batch counts equally (a mean of batch means, not weighted by
        batch size). Returns plain Python floats.
        """
        mean_loss = torch.stack([step['val_loss'] for step in outputs]).mean()
        mean_acc = torch.stack([step['val_acc'] for step in outputs]).mean()
        return {'val_loss': mean_loss.item(), 'val_acc': mean_acc.item()}
def accuracy(outputs, labels):
    """Return the fraction of rows whose highest-scoring class equals the label.

    Args:
        outputs: scores/logits with classes along dim 1.
        labels: target class indices, one per row of ``outputs``.

    Returns:
        A 0-d tensor built from a Python float (default dtype and device), so
        it carries no autograd history. An empty batch raises ZeroDivisionError.
    """
    predicted = outputs.max(dim=1).indices
    correct = (predicted == labels).sum().item()
    total = len(predicted)
    return torch.tensor(correct / total)
class Classifier(ImageClassificationBase):
    """Three-stage CNN mapping a 3-channel image to 2 class logits.

    Each stage is Conv2d(3x3, stride 1, padding 1) -> ReLU -> MaxPool2d(3, 3).
    For a 256x256 input the spatial size goes 256 -> 85 -> 28 -> 9, so the
    flattened features have 10 * 9 * 9 = 810 entries. This ties the Linear
    head to the 256x256 resize used for preprocessing. Layers live in
    ``self.network`` in this exact order, so state_dict keys stay
    ``network.<index>.*``.
    """

    def __init__(self):
        super().__init__()
        stages = []
        # (in_channels, out_channels) of the three conv stages, in order.
        for c_in, c_out in ((3, 12), (12, 15), (15, 10)):
            stages.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(3, 3),
            ])
        self.network = nn.Sequential(*stages, nn.Flatten(), nn.Linear(810, 2))

    def forward(self, xb):
        """Return raw (pre-softmax) logits for the image batch ``xb``."""
        return self.network(xb)
# Load the trained model onto the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Classifier().to(device)

# Checkpoint holding the trained weights; it is loaded as a state_dict below.
MODEL_PATH = 'PCOS_detection_20_epochs_val_acc_1.0.pth'

# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs, so a tampered checkpoint cannot run arbitrary code on
# load. map_location lets a GPU-saved checkpoint load on a CPU-only host.
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=device, weights_only=True)
)
# Inference mode. The current network has no dropout/batch-norm layers, so
# this mainly guards against such layers being added later.
model.eval()
# Human-readable labels, indexed by the model's output position
# (logit 0 -> 'infected', logit 1 -> 'not_infected').
# NOTE(review): must match the class order used at training time — update this
# list if the classes differ.
class_names = ['infected', 'not_infected']

# Inference preprocessing: resize to the 256x256 resolution the network's
# Linear(810, 2) head was sized for, then convert the PIL image to a tensor.
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
def predict_image(img):
    """Classify one ultrasound image and return per-class probabilities.

    Args:
        img: a ``PIL.Image.Image`` (what ``gr.Image(type="pil")`` supplies) or
            an array accepted by ``PIL.Image.fromarray``. ``None`` (the user
            submitted without uploading) is rejected with a ``gr.Error`` that
            Gradio shows in the UI.

    Returns:
        dict mapping each name in ``class_names`` to its softmax probability
        (float), in ``class_names`` order; this is what ``gr.Label`` renders.
    """
    if img is None:
        raise gr.Error("Please upload an image first.")
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)
    # The first Conv2d expects exactly 3 input channels. Grayscale ('L'),
    # palette ('P') or RGBA uploads would otherwise crash the forward pass.
    # NOTE(review): presumably training images were also RGB-converted — confirm.
    img = img.convert('RGB')

    img_tensor = transform(img).unsqueeze(0).to(device)  # add batch dim of 1
    with torch.no_grad():
        logits = model(img_tensor)
    probs = F.softmax(logits, dim=1)[0]

    return {name: float(p) for name, p in zip(class_names, probs)}
# ---- Gradio UI ----------------------------------------------------------------
title = "PCOS Detection from Ultrasound Images"
description = """
Upload an ultrasound image to detect PCOS (Polycystic Ovary Syndrome).
The model will classify the image as either 'infected' (PCOS positive) or 'not_infected' (PCOS negative).
"""

# One image in (as a PIL image, the type predict_image handles directly),
# class probabilities out with both classes displayed.
demo = gr.Interface(
    predict_image,
    gr.Image(type="pil"),
    gr.Label(num_top_classes=2),
    title=title,
    description=description,
    examples=[],  # no example images bundled yet
)

# Launch flags carried over from the original Colab setup.
demo.launch(share=True, debug=True)