Spaces:
Sleeping
Sleeping
File size: 6,216 Bytes
eaf1c02 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import gradio as gr
import torch
import torchvision.models as models
from torchvision.models import EfficientNet_B0_Weights # Or the specific version used
from PIL import Image
from torchvision import transforms
import json
from huggingface_hub import hf_hub_download
import os
# --- Configuration ---
# Hugging Face Hub repository that stores the trained model artifacts.
# Both files named below are downloaded from this repository at startup.
MODEL_REPO_ID = "bhumong/fruit-classifier-efficientnet-b0" # <-- REPLACE if your model lives elsewhere
# File in the repository holding the model weights (loaded as a state_dict).
MODEL_FILENAME = "pytorch_model.bin"
# JSON file in the repository; it must define 'num_labels' and 'id2label'.
CONFIG_FILENAME = "config.json"
# --- 1. Load Model and Config ---
def load_model_from_hf(repo_id, model_filename, config_filename):
    """Download an EfficientNet-B0 classifier and its config from the Hub.

    Args:
        repo_id: Hugging Face Hub repository ID that holds both files.
        model_filename: Name of the weights file (a ``state_dict``) in the repo.
        config_filename: Name of the JSON config file in the repo. It must
            define ``num_labels`` and ``id2label``.

    Returns:
        A ``(model, config, id2label, device)`` tuple. ``model`` is already
        moved to ``device`` and switched to evaluation mode, ``config`` is the
        parsed JSON dict and ``id2label`` is the label mapping taken from it.

    Raises:
        ValueError: If the config lacks ``num_labels`` or ``id2label``.
        Exception: Any download, parsing or weight-loading error is printed
            and then re-raised unchanged.
    """
    try:
        config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
        # JSON is UTF-8 by specification; do not rely on the platform default.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        print("Config loaded:", config)  # Debug print
    except Exception as e:
        print(f"Error loading config from {repo_id}/{config_filename}: {e}")
        raise  # Re-raise error if config fails

    num_labels = config.get('num_labels')
    id2label = config.get('id2label')
    if num_labels is None or id2label is None:
        raise ValueError("Config file must contain 'num_labels' and 'id2label'")

    # Architecture only: the fine-tuned weights are loaded from the Hub below.
    model = models.efficientnet_b0(weights=None)

    # Replace the final linear layer so its output size matches the label count.
    try:
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = torch.nn.Linear(num_ftrs, num_labels)
    except Exception as e:
        print(f"Error modifying model classifier: {e}")
        raise

    # Download and load model weights
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # The file comes from a remote repository. weights_only=True limits
        # unpickling to tensors and plain containers, so a tampered file
        # cannot execute arbitrary code while being loaded.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model.to(device)  # Move model to device
        model.eval()  # Set to evaluation mode
        print(f"Model loaded successfully from {repo_id} to device {device}.")
        return model, config, id2label, device
    except Exception as e:
        print(f"Error loading model weights from {repo_id}/{model_filename}: {e}")
        raise
# Load the model once, at script start, so every prediction reuses it.
# The names are bound to None first: if loading fails they stay defined,
# which lets the rest of the script detect the failure instead of crashing
# with a NameError.
model = config = id2label = device = None
try:
    model, config, id2label, device = load_model_from_hf(
        MODEL_REPO_ID, MODEL_FILENAME, CONFIG_FILENAME
    )
except Exception as load_error:
    print(f"FATAL: Could not load model or config. Gradio app cannot start. Error: {load_error}")
# --- 2. Define Preprocessing ---
# Size every input image is resized to before it reaches the network.
IMG_SIZE = (224, 224)
# Per-channel normalization constants (these are the widely used ImageNet
# values; presumably the same ones applied during training -- confirm).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# Steps run in order: resize -> convert to tensor -> normalize channels.
_preprocess_steps = [
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
]
preprocess = transforms.Compose(_preprocess_steps)
# --- 3. Define Prediction Function ---
def _error_result(message):
    """Return an error as a valid single-entry ``{label: confidence}`` dict."""
    return {f"Error: {message}": 1.0}


def predict(inp_image):
    """Classify a PIL image and return per-label confidences.

    Args:
        inp_image: The PIL image to classify, or ``None`` when nothing was
            uploaded.

    Returns:
        A ``{label: probability}`` dict covering every entry of ``id2label``.
        On failure (model not loaded, no image, or an exception during
        inference) a single-entry dict ``{"Error: <reason>": 1.0}`` is
        returned instead. Every value is a float because the Label output
        treats values as numeric confidences; a string value would make the
        error report itself fail.
    """
    if model is None or id2label is None:
        return _error_result("Model not loaded")  # Handle model load failure
    if inp_image is None:
        return _error_result("No image provided")
    try:
        # Ensure image is RGB (uploads may be grayscale, palette or RGBA).
        img = inp_image.convert("RGB")
        # Add the batch dimension and move the tensor to the model's device.
        input_batch = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(input_batch)
            probabilities = torch.nn.functional.softmax(output[0], dim=0)
        # One transfer to Python floats instead of one conversion per class.
        probs = probabilities.tolist()
        # id2label comes from JSON, so its keys are strings.
        return {id2label[str(i)]: probs[i] for i in range(len(id2label))}
    except Exception as e:
        print(f"Error during prediction: {e}")
        return _error_result(f"Prediction failed: {e}")
# --- 4. Create Gradio Interface ---
def filter_existing_examples(candidates):
    """Return the example rows whose image file exists on disk.

    Args:
        candidates: List of example rows; the first item of each row is the
            path of the example image.

    Returns:
        A new list with the rows whose path exists, in the original order.
    """
    return [row for row in candidates if os.path.exists(row[0])]


# Add example images (Make sure these paths exist within your Space repo!)
# Create an 'images' folder in your Space and upload some examples.
example_list = [
    ["images/example_apple.jpg"],  # <-- REPLACE with actual paths in your Space repo
    ["images/example_banana.jpg"],  # <-- REPLACE with actual paths in your Space repo
    ["images/example_strawberry.jpg"],  # <-- REPLACE with actual paths in your Space repo
]
# Keep only the examples that are really present, so a single missing file
# no longer discards the valid ones. The list is empty when none exist.
_available_examples = filter_existing_examples(example_list)
_missing_count = len(example_list) - len(_available_examples)
if _missing_count:
    print(f"Warning: {_missing_count} example image path(s) not found. Skipping them.")
example_list = _available_examples
# Texts shown by the Gradio app: page title, the description above the
# interface and the article (footer) below it.
# NOTE(review): the trailing characters of the title look like mis-encoded
# emoji -- confirm the intended glyphs before changing the string.
title = "Fruit Classifier πππ"
# The repository ID is interpolated so the links follow MODEL_REPO_ID.
description = f"""
Upload an image of a fruit or use one of the examples below.
This demo uses an EfficientNet-B0 model fine-tuned on the Fruits-360 dataset
(with merged classes) to predict the fruit type.
Model hosted on Hugging Face Hub: [{MODEL_REPO_ID}](https://huggingface.co/{MODEL_REPO_ID})
"""
article = f"""
<div style='text-align: center;'>
Model trained using PyTorch and tracked with Neptune.ai. |
<a href='https://huggingface.co/{MODEL_REPO_ID}' target='_blank'>Model Repository</a> |
Built with Gradio
</div>
"""
# Build and start the interface only when the model loaded successfully;
# otherwise report why nothing is launched.
if model is None:
    print("Gradio interface not launched due to model loading failure.")
else:
    image_input = gr.Image(type="pil", label="Upload Fruit Image")
    # Show the five most probable classes.
    label_output = gr.Label(num_top_classes=5, label="Predictions")
    iface = gr.Interface(
        fn=predict,
        inputs=image_input,
        outputs=label_output,
        title=title,
        description=description,
        article=article,
        examples=example_list,
        allow_flagging="never",  # Optional: disable flagging
    )
    iface.launch()
|