# Food_Vision_101 / app.py
# Author: Shriharsh — last commit e9bb170 ("Update app.py", verified)
import importlib.metadata
import logging
import os
from timeit import default_timer as timer
from typing import Dict, Tuple

import gradio as gr
import pkg_resources
import torch

from model import create_effnetb2_model
# Configure root logging: INFO and above, each record prefixed with its
# timestamp and level name, then a module-level logger for this file.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
# Check Gradio version
try:
gradio_version = pkg_resources.get_distribution("gradio").version
logger.info(f"Using Gradio version: {gradio_version}")
except pkg_resources.DistributionNotFound:
raise ImportError("Gradio is not installed. Please install it using 'pip install gradio'.")
# Load the human-readable class names, one per line. The list is paired with
# the model's output scores by position, so line order matters.
try:
    with open("class_names.txt", "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. extra newlines at the end of the file) so they
        # do not become phantom classes that change the list length.
        class_names = [name for name in (line.strip() for line in f) if name]
    logger.info("Class names loaded successfully")
except FileNotFoundError as err:
    logger.error("class_names.txt not found")
    raise FileNotFoundError("class_names.txt not found.") from err
# Build the EfficientNetB2 model with a 101-class head, plus the transform
# that predict() applies to each input image before inference.
try:
    effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=101)
    logger.info("EfficientNetB2 model created successfully")
except Exception as e:
    # logger.exception also records the original traceback; chaining with
    # `from e` keeps the root cause visible instead of a bare Exception.
    logger.exception(f"Error creating model: {e}")
    raise RuntimeError(f"Error creating model: {e}") from e
# Load the trained Food-101 weights into the model. map_location moves every
# tensor to the CPU, so weights saved on a GPU still load on a CPU-only host.
try:
    effnetb2.load_state_dict(
        torch.load(
            "09_pretrained_effnetb2_feature_extractor_food101.pth",
            map_location=torch.device("cpu"),
            # The checkpoint is only used as a state_dict, so restrict
            # unpickling to tensors and primitive containers rather than
            # allowing arbitrary pickled objects.
            weights_only=True,
        )
    )
    logger.info("Model weights loaded successfully")
except FileNotFoundError as e:
    logger.error("Model weights file not found")
    raise FileNotFoundError("Model weights file not found.") from e
except Exception as e:
    # Covers unpickling failures and state_dict key/shape mismatches;
    # logger.exception keeps the original traceback in the log.
    logger.exception(f"Error loading weights: {e}")
    raise RuntimeError(f"Error loading weights: {e}") from e
# Predict function
def predict(img) -> Tuple[Dict, float]:
    """Classify a food image with the EfficientNetB2 model.

    Args:
        img: The image supplied by the ``gr.Image(type="pil")`` input (a PIL
            image), or ``None`` when the user submitted without an image.

    Returns:
        A ``(label_to_probability, seconds)`` tuple: a dict mapping every class
        name to its softmax probability, and the wall-clock prediction time
        rounded to 5 decimal places.

    Raises:
        gr.Error: If no image was given or inference fails. Gradio displays
            the message in the UI; a dict with a string value cannot be
            rendered by the ``gr.Label`` output.
    """
    start_time = timer()
    try:
        if img is None:
            raise ValueError("Input image is None.")
        # Preprocess, then add a leading batch dimension of size 1.
        batch = effnetb2_transforms(img).unsqueeze(0)
        effnetb2.eval()
        with torch.inference_mode():
            # Softmax over the class dimension; take the single batch row.
            probs = torch.softmax(effnetb2(batch), dim=1)[0].tolist()
        # Names and scores are paired by position, so a length mismatch would
        # silently mislabel or drop classes.
        if len(probs) != len(class_names):
            raise ValueError(
                f"Model produced {len(probs)} scores but "
                f"{len(class_names)} class names were loaded."
            )
        pred_labels_and_probs = dict(zip(class_names, probs))
    except Exception as e:
        logger.exception(f"Prediction failed: {e}")
        raise gr.Error(f"Prediction failed: {e}") from e

    pred_time = round(timer() - start_time, 5)
    # Log only the top class rather than all probabilities on every request.
    top_label = max(pred_labels_and_probs, key=pred_labels_and_probs.get)
    logger.info(
        f"Prediction completed: {top_label} "
        f"({pred_labels_and_probs[top_label]:.4f}), Time: {pred_time}"
    )
    return pred_labels_and_probs, pred_time
# Gradio app
# The title ends with the hamburger and eye emoji. They are written as escapes
# so the source stays ASCII and cannot be mis-decoded again (the previous
# literal was the UTF-8 bytes of these emoji read in the wrong encoding).
title = "FoodVision 101 \U0001F354\U0001F441"
description = "An EfficientNetB2 feature extractor to classify 101 food classes."
# Collect example images for the UI; each example is a one-element list
# because the interface has a single input. Sorting gives a stable order
# (os.listdir order is arbitrary). Hidden files (e.g. .DS_Store) and
# subdirectories are skipped because they cannot be loaded as images.
# A missing directory just disables examples.
try:
    example_list = [
        [os.path.join("examples", name)]
        for name in sorted(os.listdir("examples"))
        if not name.startswith(".")
        and os.path.isfile(os.path.join("examples", name))
    ]
    logger.info("Examples loaded successfully")
except FileNotFoundError:
    example_list = []
    logger.warning("'examples/' directory not found")
# Simplified Gradio interface: one PIL image in; the top-5 class probabilities
# and the prediction time out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=5, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    allow_flagging="never",  # Never show the Flag button.
)
# Start the web server when run as a script (Hugging Face Spaces runs
# `python app.py`). Gradio does not use share=True on Spaces; when run
# locally, it creates a temporary public share link.
if __name__ == "__main__":
    logger.info("Launching Gradio app")
    try:
        # launch() blocks the main thread until the server shuts down, so
        # anything after it runs only on shutdown.
        demo.launch(share=True)
    except Exception as e:
        logger.exception(f"Failed to launch Gradio app: {e}")
        raise RuntimeError(f"Failed to launch Gradio app: {e}") from e
    logger.info("Gradio app stopped")