File size: 3,663 Bytes
3b23212
 
 
 
 
 
839ec63
e9bb170
 
 
 
 
839ec63
 
 
 
e9bb170
839ec63
 
3b23212
752ca4f
4a739e5
752ca4f
4a739e5
e9bb170
4a739e5
e9bb170
752ca4f
3b23212
839ec63
4a739e5
752ca4f
e9bb170
4a739e5
e9bb170
4a739e5
3b23212
752ca4f
4a739e5
 
 
752ca4f
 
4a739e5
07270f5
e9bb170
4a739e5
e9bb170
752ca4f
4a739e5
e9bb170
752ca4f
3b23212
839ec63
3b23212
4a739e5
 
 
752ca4f
4a739e5
 
 
 
752ca4f
4a739e5
e9bb170
4a739e5
 
e9bb170
4a739e5
3b23212
839ec63
07270f5
752ca4f
3b23212
4a739e5
 
e9bb170
4a739e5
 
e9bb170
3b23212
839ec63
07270f5
 
 
 
 
 
 
 
 
 
e9bb170
 
07270f5
 
839ec63
 
 
e9bb170
839ec63
e9bb170
839ec63
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import gradio as gr
import os
import torch
from model import create_effnetb2_model
from timeit import default_timer as timer
from typing import Tuple, Dict
import pkg_resources
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Check Gradio version.
# importlib.metadata is the standard-library replacement for the deprecated
# pkg_resources API. Because `import gradio` above already succeeded, the
# except branch only triggers if gradio is importable but has no installed
# distribution metadata (e.g. a vendored copy).
from importlib import metadata as importlib_metadata

try:
    gradio_version = importlib_metadata.version("gradio")
except importlib_metadata.PackageNotFoundError as err:
    raise ImportError("Gradio is not installed. Please install it using 'pip install gradio'.") from err
logger.info("Using Gradio version: %s", gradio_version)

# Load class names: one label per line. The line order must match the model's
# output indices, because predict() pairs them positionally.
try:
    with open("class_names.txt", "r", encoding="utf-8") as f:
        class_names = [food_name.strip() for food_name in f]
except FileNotFoundError as err:
    logger.error("class_names.txt not found")
    raise FileNotFoundError("class_names.txt not found.") from err

# Trailing blank lines (e.g. an extra newline at end of file) would otherwise
# become empty "classes". Only trailing ones are dropped, so indices of the
# real labels never shift.
while class_names and not class_names[-1]:
    class_names.pop()
logger.info("Class names loaded successfully (%d classes)", len(class_names))

# Model and transforms preparation.
# The classifier head size must match both the saved weights (Food101) and
# the number of labels in class_names. Check up front so a bad labels file
# fails at startup with a clear message, not later inside predict().
NUM_CLASSES = 101
if len(class_names) != NUM_CLASSES:
    logger.error("Expected %d class names, found %d", NUM_CLASSES, len(class_names))
    raise ValueError(
        f"class_names.txt has {len(class_names)} entries; expected {NUM_CLASSES}."
    )

try:
    effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=NUM_CLASSES)
    logger.info("EfficientNetB2 model created successfully")
except Exception as e:
    logger.exception("Error creating model: %s", e)
    # RuntimeError is still an Exception subclass, so existing handlers keep
    # working, and `from e` preserves the original traceback.
    raise RuntimeError(f"Error creating model: {e}") from e

# Load weights
WEIGHTS_PATH = "09_pretrained_effnetb2_feature_extractor_food101.pth"
try:
    # NOTE(review): torch.load unpickles the file. That is fine for this bundled
    # local checkpoint. If the file could ever come from an untrusted source,
    # pass weights_only=True (supported by recent PyTorch releases).
    state_dict = torch.load(WEIGHTS_PATH, map_location=torch.device("cpu"))
    effnetb2.load_state_dict(state_dict)
    logger.info("Model weights loaded successfully")
except FileNotFoundError as e:
    logger.error("Model weights file not found")
    raise FileNotFoundError("Model weights file not found.") from e
except Exception as e:
    logger.exception("Error loading weights: %s", e)
    raise RuntimeError(f"Error loading weights: {e}") from e

# This process only runs inference, so switch the model to eval mode once
# right after loading. This affects layers such as dropout and batch-norm.
effnetb2.eval()

# Predict function
def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify one image with EfficientNetB2 and time the inference.

    Args:
        img: Image from the ``gr.Image(type="pil")`` input, or ``None`` when
            the form is submitted without an image.

    Returns:
        ``(label_to_probability, seconds)``. The dict maps every class name
        to its softmax probability and is rendered by ``gr.Label``. The float
        is the inference wall-clock time in seconds, rounded to 5 decimals.
        On any failure, the error and traceback are logged and
        ``({"Prediction failed: <reason>": 0.0}, 0.0)`` is returned. The
        confidence is numeric so ``gr.Label`` can still display the message.
    """
    start_time = timer()
    try:
        if img is None:
            raise ValueError("Input image is None.")
        batch = effnetb2_transforms(img).unsqueeze(0)  # prepend a batch dim of size 1
        effnetb2.eval()
        with torch.inference_mode():
            pred_probs = torch.softmax(effnetb2(batch), dim=1)
        # Probabilities for the single image in the batch, as Python floats.
        probs = pred_probs[0].tolist()
        if len(probs) != len(class_names):
            # Report a label/model mismatch instead of silently truncating.
            raise ValueError(
                f"Model returned {len(probs)} scores but {len(class_names)} class names are loaded."
            )
        pred_labels_and_probs = dict(zip(class_names, probs))
        pred_time = round(timer() - start_time, 5)
        # Log only the top class; the full 101-entry dict per request floods the log.
        top_label = max(pred_labels_and_probs, key=pred_labels_and_probs.get)
        logger.info(
            "Prediction completed: top=%s (%.4f), Time: %s",
            top_label,
            pred_labels_and_probs[top_label],
            pred_time,
        )
        return pred_labels_and_probs, pred_time
    except Exception as e:
        logger.exception("Prediction failed: %s", e)
        return {f"Prediction failed: {e}": 0.0}, 0.0

# Gradio app
title = "FoodVision 101 🍔👁"
description = "An EfficientNetB2 feature extractor to classify 101 food classes."

# Each example is a one-element list because the interface has a single input.
# Names are sorted so the display order is stable (os.listdir order is
# arbitrary). Hidden files (e.g. .DS_Store, .gitkeep) and sub-directories are
# skipped because Gradio would try to load them as example images.
EXAMPLES_DIR = "examples"
try:
    example_list = [
        [os.path.join(EXAMPLES_DIR, example)]
        for example in sorted(os.listdir(EXAMPLES_DIR))
        if not example.startswith(".")
        and os.path.isfile(os.path.join(EXAMPLES_DIR, example))
    ]
    logger.info("Examples loaded successfully")
except FileNotFoundError:
    example_list = []
    logger.warning("'examples/' directory not found")

# Gradio interface: one PIL image in. Two outputs matching predict()'s
# (dict, float) return: the top-5 label probabilities and the inference time
# in seconds.
# NOTE(review): `api_mode=False` was dropped. It is not a public gr.Interface
# parameter (Gradio only has a private `_api_mode`), so it never disabled
# anything. If the goal is to hide the API docs page, launch(show_api=False)
# is the usual switch — confirm against the installed Gradio version.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=5, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    allow_flagging="never",  # Disable flagging
)

# Launch the app.
# When this file is run as a script, launch() normally blocks while the server
# is running. So progress is logged *before* the call; a log line after it
# would only appear once the server has stopped.
# NOTE(review): share=True only matters for local runs, where it requests a
# temporary public link. Hugging Face Spaces already serves the app publicly —
# confirm whether the flag is still wanted.
logger.info("Launching Gradio app")
try:
    demo.launch(share=True)
except Exception as e:
    logger.exception("Failed to launch Gradio app: %s", e)
    # RuntimeError remains an Exception subclass; `from e` keeps the root cause.
    raise RuntimeError(f"Failed to launch Gradio app: {e}") from e