import gradio as gr
import torch
import torchvision.models as models
from torchvision.models import EfficientNet_B0_Weights # Or the specific version used
from PIL import Image
from torchvision import transforms
import json
from huggingface_hub import hf_hub_download
import os

# --- Configuration ---
# This should be the ID of the repository where your MODEL is stored
MODEL_REPO_ID = "bhumong/fruit-classifier-efficientnet-b0" # <-- REPLACE if different
MODEL_FILENAME = "pytorch_model.bin"
CONFIG_FILENAME = "config.json"
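# Illustrative sketch of the config.json structure this script expects (the actual
# label names and count depend on the training run; only 'num_labels' and 'id2label'
# are required by load_model_from_hf below):
# {
#   "num_labels": 3,
#   "id2label": {"0": "Apple", "1": "Banana", "2": "Strawberry"}
# }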

# --- 1. Load Model and Config ---
# (Using the function defined previously to load from Hub)
def load_model_from_hf(repo_id, model_filename, config_filename):
    """Loads model state_dict and config from Hugging Face Hub."""
    try:
        config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
        with open(config_path, 'r') as f:
            config = json.load(f)
        print("Config loaded:", config) # Debug print
    except Exception as e:
        print(f"Error loading config from {repo_id}/{config_filename}: {e}")
        raise # Re-raise error if config fails

    num_labels = config.get('num_labels')
    id2label = config.get('id2label')

    if num_labels is None or id2label is None:
        raise ValueError("Config file must contain 'num_labels' and 'id2label'")

    # Instantiate the correct architecture (EfficientNet-B0)
    model = models.efficientnet_b0(weights=None) # Load architecture only

    # Modify the classifier head
    try:
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = torch.nn.Linear(num_ftrs, num_labels)
    except Exception as e:
        print(f"Error modifying model classifier: {e}")
        raise

    # Download and load model weights
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
        model.to(device) # Move model to device
        model.eval() # Set to evaluation mode
        print(f"Model loaded successfully from {repo_id} to device {device}.")
        return model, config, id2label, device
    except Exception as e:
        print(f"Error loading model weights from {repo_id}/{model_filename}: {e}")
        raise

# Load the model globally when the script starts
try:
    model, config, id2label, device = load_model_from_hf(MODEL_REPO_ID, MODEL_FILENAME, CONFIG_FILENAME)
except Exception as e:
    print(f"FATAL: Could not load model or config. Gradio app cannot start. Error: {e}")
    # Optionally, exit or raise a specific error for Gradio to catch if possible
    model, config, id2label, device = None, None, None, None # Prevent further errors


# --- 2. Define Preprocessing ---
IMG_SIZE = (224, 224)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

preprocess = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
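# Note: the mean/std above are the standard ImageNet normalization statistics used by
# torchvision's pretrained EfficientNet-B0 weights, and 224x224 matches the model's
# expected input resolution. A single image can be preprocessed like this
# (illustrative; 'some_pil_image' is a placeholder):
#   input_tensor = preprocess(some_pil_image.convert("RGB"))  # shape: [3, 224, 224]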

# --- 3. Define Prediction Function ---
def predict(inp_image):
    """Takes a PIL image, preprocesses, predicts, and returns label confidences."""
    if model is None or id2label is None:
        # Handle model load failure; gr.Label expects a {label: float} mapping only
        return {"Error: model not loaded": 1.0}

    if inp_image is None:
        return {"Error: no image provided": 1.0}

    try:
        # Ensure image is RGB
        img = inp_image.convert("RGB")
        input_tensor = preprocess(img)
        input_batch = input_tensor.unsqueeze(0) # Add batch dimension
        input_batch = input_batch.to(device) # Move tensor to the correct device

        with torch.no_grad():
            output = model(input_batch)
            probabilities = torch.nn.functional.softmax(output[0], dim=0)

        # Prepare output for Gradio Label component (dictionary {label: probability})
        confidences = {id2label[str(i)]: float(probabilities[i]) for i in range(len(id2label))}
        return confidences

    except Exception as e:
        print(f"Error during prediction: {e}")
        # gr.Label expects numeric confidences, so fold the message into the label key
        return {f"Error: prediction failed ({e})": 1.0}

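# Optional local sanity check (illustrative; assumes a test image "sample.jpg" next to
# this script, adjust the path before uncommenting):
# if model is not None and os.path.exists("sample.jpg"):
#     print(predict(Image.open("sample.jpg")))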

# --- 4. Create Gradio Interface ---

# Add example images (Make sure these paths exist within your Space repo!)
# Create an 'images' folder in your Space and upload some examples.
example_list = [
    ["images/example_apple.jpg"],      # <-- REPLACE with actual paths in your Space repo
    ["images/example_banana.jpg"],     # <-- REPLACE with actual paths in your Space repo
    ["images/example_strawberry.jpg"]  # <-- REPLACE with actual paths in your Space repo
]
# Check if example files exist, otherwise provide empty list
if not all(os.path.exists(ex[0]) for ex in example_list):
    print("Warning: Example image paths not found. Clearing examples.")
    example_list = []
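# Illustrative Space repository layout for the example paths above (filenames are
# placeholders; use whichever images you actually upload):
#   app.py
#   images/
#       example_apple.jpg
#       example_banana.jpg
#       example_strawberry.jpg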


# Define Title, Description, and Article for the Gradio app
title = "Fruit Classifier πŸŽπŸŒπŸ“"
description = """
Upload an image of a fruit or use one of the examples below.
This demo uses an EfficientNet-B0 model fine-tuned on the Fruits-360 dataset
(with merged classes) to predict the fruit type.
Model hosted on Hugging Face Hub: [{MODEL_REPO_ID}](https://huggingface.co/{MODEL_REPO_ID})
""".format(MODEL_REPO_ID=MODEL_REPO_ID) # Format description with repo ID
article = """
<div style='text-align: center;'>
Model trained using PyTorch and tracked with Neptune.ai. |
<a href='https://huggingface.co/{MODEL_REPO_ID}' target='_blank'>Model Repository</a> |
Built with Gradio
</div>
""".format(MODEL_REPO_ID=MODEL_REPO_ID)

# Create and launch the interface
if model is not None: # Only launch if model loaded successfully
    iface = gr.Interface(
        fn=predict,
        inputs=gr.Image(type="pil", label="Upload Fruit Image"),
        outputs=gr.Label(num_top_classes=5, label="Predictions"), # Show top 5 predictions
        title=title,
        description=description,
        article=article,
        examples=example_list,
        allow_flagging="never" # Optional: disable flagging
    )

    iface.launch()
else:
    print("Gradio interface not launched due to model loading failure.")