Spaces:
Runtime error
Runtime error
File size: 3,663 Bytes
3b23212 839ec63 e9bb170 839ec63 e9bb170 839ec63 3b23212 752ca4f 4a739e5 752ca4f 4a739e5 e9bb170 4a739e5 e9bb170 752ca4f 3b23212 839ec63 4a739e5 752ca4f e9bb170 4a739e5 e9bb170 4a739e5 3b23212 752ca4f 4a739e5 752ca4f 4a739e5 07270f5 e9bb170 4a739e5 e9bb170 752ca4f 4a739e5 e9bb170 752ca4f 3b23212 839ec63 3b23212 4a739e5 752ca4f 4a739e5 752ca4f 4a739e5 e9bb170 4a739e5 e9bb170 4a739e5 3b23212 839ec63 07270f5 752ca4f 3b23212 4a739e5 e9bb170 4a739e5 e9bb170 3b23212 839ec63 07270f5 e9bb170 07270f5 839ec63 e9bb170 839ec63 e9bb170 839ec63 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import inspect
import logging
import os
from timeit import default_timer as timer
from typing import Tuple, Dict

import gradio as gr
import pkg_resources
import torch

from model import create_effnetb2_model
# Configure root logging once for the whole app ("time - LEVEL - message")
# and create the logger used by the rest of this module.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
# Check Gradio version
try:
gradio_version = pkg_resources.get_distribution("gradio").version
logger.info(f"Using Gradio version: {gradio_version}")
except pkg_resources.DistributionNotFound:
raise ImportError("Gradio is not installed. Please install it using 'pip install gradio'.")
# Load the class labels, one per line. predict() pairs output index i with
# class_names[i], so the file order must match the model's output order
# (assumed to be the training label order — TODO confirm).
CLASS_NAMES_PATH = "class_names.txt"
try:
    with open(CLASS_NAMES_PATH, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. an extra empty line at the end of the file):
        # an empty label would shift the label count away from the model's
        # number of outputs.
        class_names = [line.strip() for line in f if line.strip()]
except FileNotFoundError as err:
    logger.error("%s not found", CLASS_NAMES_PATH)
    raise FileNotFoundError(f"{CLASS_NAMES_PATH} not found.") from err
else:
    logger.info("Class names loaded successfully")
# Model and transforms preparation.
# Build the EfficientNetB2 model together with the preprocessing transforms
# it expects. NUM_CLASSES must match the classifier head of the checkpoint
# loaded afterwards, because load_state_dict is strict by default.
NUM_CLASSES = 101
try:
    effnetb2, effnetb2_transforms = create_effnetb2_model(num_classes=NUM_CLASSES)
except Exception as e:
    logger.exception("Error creating model: %s", e)
    # Chain the original error so its traceback is not lost.
    raise RuntimeError(f"Error creating model: {e}") from e
else:
    logger.info("EfficientNetB2 model created successfully")

# predict() labels output index i with class_names[i]. Warn early if the label
# file and the model disagree, instead of failing on the first request.
if len(class_names) != NUM_CLASSES:
    logger.warning(
        "Label count (%d) does not match model outputs (%d)",
        len(class_names),
        NUM_CLASSES,
    )
# Load weights.
# Restore the trained parameters into `effnetb2`. map_location maps every
# tensor onto the CPU, so the checkpoint loads even without a GPU.
WEIGHTS_PATH = "09_pretrained_effnetb2_feature_extractor_food101.pth"
try:
    effnetb2.load_state_dict(
        torch.load(WEIGHTS_PATH, map_location=torch.device("cpu"))
    )
except FileNotFoundError as e:
    logger.error("Model weights file not found: %s", WEIGHTS_PATH)
    raise FileNotFoundError("Model weights file not found.") from e
except Exception as e:
    # E.g. a state-dict key/shape mismatch; keep the original traceback.
    logger.exception("Error loading weights: %s", e)
    raise RuntimeError(f"Error loading weights: {e}") from e
else:
    logger.info("Model weights loaded successfully")
# Predict function
def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify one image and time how long the prediction took.

    Args:
        img: The value delivered by the ``gr.Image(type="pil")`` input (a PIL
            image), or ``None`` when no image was provided.

    Returns:
        ``(label_to_prob, seconds)``: a dict mapping every entry of
        ``class_names`` to its softmax probability, and the wall-clock time
        of the prediction rounded to 5 decimal places.

    Raises:
        gr.Error: if the image is missing or inference fails. Gradio shows
            this message to the user. Returning an error dict instead would
            send a string "confidence" to the ``gr.Label`` output, which
            expects numeric values.
    """
    start_time = timer()
    try:
        if img is None:
            raise ValueError("Input image is None.")
        batch = effnetb2_transforms(img).unsqueeze(0)  # add a batch dimension
        effnetb2.eval()
        with torch.inference_mode():
            # Softmax over the class dimension; keep the single batch row.
            probs = torch.softmax(effnetb2(batch), dim=1)[0].tolist()
        # Indexing (rather than zip) keeps a label/output count mismatch loud.
        pred_labels_and_probs = {name: probs[i] for i, name in enumerate(class_names)}
        pred_time = round(timer() - start_time, 5)
        # Log only the top label; the full 101-entry dict floods the logs.
        top_label = max(pred_labels_and_probs, key=pred_labels_and_probs.get, default=None)
        logger.info("Prediction completed: top=%s, Time: %s", top_label, pred_time)
        return pred_labels_and_probs, pred_time
    except Exception as e:
        logger.exception("Prediction failed: %s", e)
        raise gr.Error(f"Prediction failed: {e}") from e
# Gradio app
title = "FoodVision 101 🍔👁"
description = "An EfficientNetB2 feature extractor to classify 101 food classes."

# Each example is a one-element list: one value per Interface input.
EXAMPLES_DIR = "examples"
try:
    # Sort so the example order is stable across machines. Skip hidden files
    # (e.g. .DS_Store), since they are not example images.
    example_list = [
        [os.path.join(EXAMPLES_DIR, name)]
        for name in sorted(os.listdir(EXAMPLES_DIR))
        if not name.startswith(".")
    ]
except (FileNotFoundError, NotADirectoryError):
    example_list = []
    logger.warning("'examples/' directory not found")
else:
    logger.info("Examples loaded successfully")
# Simplified Gradio interface.
# Gradio 5 renamed `allow_flagging` to `flagging_mode`. Use whichever keyword
# the installed version accepts, so flagging stays disabled on either version.
_flagging_kwarg = (
    "flagging_mode"
    if "flagging_mode" in inspect.signature(gr.Interface.__init__).parameters
    else "allow_flagging"
)

# NOTE: pass only documented gr.Interface arguments. Unknown keywords (such
# as an `api_mode` flag) are rejected with a TypeError at startup by Gradio 4+.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Label(num_top_classes=5, label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list,
    title=title,
    description=description,
    **{_flagging_kwarg: "never"},
)
# Launch the app when this file is run as a script.
# share=True is intentionally not set: a hosted Space already serves the app
# publicly (Gradio ignores `share` there), and running locally with share=True
# would open a public tunnel to this machine.
if __name__ == "__main__":
    # launch() blocks while the server runs, so log before calling it.
    logger.info("Launching Gradio app")
    try:
        demo.launch()
    except Exception as e:
        logger.exception("Failed to launch Gradio app: %s", e)
        raise RuntimeError(f"Failed to launch Gradio app: {e}") from e