"""Gradio demo that classifies ultrasound images as PCOS-positive or negative.

Loads a small CNN (``Classifier``) from a saved state dict and serves
``predict_image`` through a Gradio interface.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr

# Weights file produced by the training run; must sit next to this script
# (in Colab, upload it first, e.g. with google.colab.files.upload()).
MODEL_PATH = 'PCOS_detection_20_epochs_val_acc_1.0.pth'


def accuracy(outputs, labels):
    """Return the fraction of rows whose argmax in *outputs* equals *labels*.

    Args:
        outputs: 2-D logits tensor, one row per sample.
        labels: 1-D tensor of integer class indices, same length as *outputs*.

    Returns:
        A 0-d float tensor in [0, 1].
    """
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))


class ImageClassificationBase(nn.Module):
    """Base module providing the validation helpers used during training.

    Not needed for inference, but kept so the class hierarchy matches the
    one the weights were trained with.
    """

    def validation_step(self, batch):
        """Compute loss and accuracy for one ``(images, labels)`` batch."""
        images, labels = batch
        out = self(images)
        loss = F.cross_entropy(out, labels)
        acc = accuracy(out, labels)
        return {'val_loss': loss.detach(), 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch results returned by ``validation_step``."""
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}


class Classifier(ImageClassificationBase):
    """Three conv/ReLU/max-pool stages followed by a linear 2-way head.

    For a 3x256x256 input the spatial size shrinks 256 -> 85 -> 28 -> 9,
    so the flattened feature vector has 10 * 9 * 9 = 810 elements, matching
    the ``nn.Linear(810, 2)`` head. The input must therefore be resized to
    256x256 with exactly 3 channels.
    """

    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(3, 12, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(3, 3),
            nn.Conv2d(12, 15, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(3, 3),
            nn.Conv2d(15, 10, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(3, 3),
            nn.Flatten(),
            nn.Linear(810, 2),
        )

    def forward(self, xb):
        """Return raw (unnormalized) logits for a batch of images."""
        return self.network(xb)


# Load the trained model.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Classifier().to(device)
# NOTE(review): torch.load unpickles the file; only load trusted weights.
# On torch >= 1.13 consider passing weights_only=True.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()

# Index i of the model output corresponds to class_names[i]. This order
# presumably matches the label indices used in training (e.g. ImageFolder's
# alphabetical order) -- confirm before changing.
class_names = ['infected', 'not_infected']

# Must match the 256x256 input size the Linear(810, 2) head was built for.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])


def predict_image(img):
    """Classify one image and return per-class probabilities.

    Args:
        img: A PIL image or a NumPy array convertible via ``Image.fromarray``.

    Returns:
        Dict mapping each name in ``class_names`` to its softmax probability,
        the format expected by ``gr.Label``.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Gradio passes None when the user submits without uploading.
        raise gr.Error("Please upload an image first.")
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)
    # The first conv layer expects exactly 3 channels; grayscale ('L'),
    # palette ('P') or RGBA uploads would otherwise yield 1- or 4-channel
    # tensors and crash the forward pass.
    img = img.convert('RGB')

    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(img_tensor)
    probabilities = F.softmax(outputs, dim=1)[0]

    return {name: float(prob) for name, prob in zip(class_names, probabilities)}


# Gradio interface.
title = "PCOS Detection from Ultrasound Images"
description = """
Upload an ultrasound image to detect PCOS (Polycystic Ovary Syndrome).
The model will classify the image as either 'infected' (PCOS positive) or 'not_infected' (PCOS negative).
"""

# To show sample inputs, pass examples=[[path], ...] to gr.Interface.
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=2),
    title=title,
    description=description,
)

if __name__ == "__main__":
    # share=True creates a public URL; anyone with the link can submit
    # images, so disable it for sensitive deployments.
    demo.launch(debug=True, share=True)