import json
import os

import gradio as gr
import torch
import torchvision.models as models
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from torchvision.models import EfficientNet_B0_Weights  # Or the specific version used

# --- Configuration ---
# This should be the ID of the repository where your MODEL is stored
MODEL_REPO_ID = "bhumong/fruit-classifier-efficientnet-b0"  # <-- REPLACE if different
MODEL_FILENAME = "pytorch_model.bin"
CONFIG_FILENAME = "config.json"


# --- 1. Load Model and Config ---
def load_model_from_hf(repo_id, model_filename, config_filename):
    """Load the model state_dict and config from the Hugging Face Hub.

    Args:
        repo_id: Hub repository that holds both the weights and the config.
        model_filename: Name of the state_dict file inside the repository.
        config_filename: Name of the JSON config file inside the repository.

    Returns:
        A ``(model, config, id2label, device)`` tuple. ``model`` is an
        EfficientNet-B0 whose final classifier layer has ``num_labels``
        outputs, already moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: If the config lacks ``num_labels`` or ``id2label``.
        Exception: Any download, parsing or weight-loading error is printed
            and then re-raised unchanged.
    """
    try:
        config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        print("Config loaded:", config)  # Debug print
    except Exception as e:
        print(f"Error loading config from {repo_id}/{config_filename}: {e}")
        raise  # Re-raise error if config fails

    num_labels = config.get('num_labels')
    id2label = config.get('id2label')
    if num_labels is None or id2label is None:
        raise ValueError("Config file must contain 'num_labels' and 'id2label'")

    # Instantiate the correct architecture (EfficientNet-B0); weights=None
    # builds the architecture only, the trained weights are loaded below.
    model = models.efficientnet_b0(weights=None)

    # Replace the classifier head so its output size matches num_labels.
    try:
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = torch.nn.Linear(num_ftrs, num_labels)
    except Exception as e:
        print(f"Error modifying model classifier: {e}")
        raise

    # Download and load model weights
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): torch.load unpickles the downloaded file. Since the file
        # comes from a remote repository, consider passing weights_only=True if
        # the installed PyTorch version supports it.
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
        model.to(device)  # Move model to device
        model.eval()  # Set to evaluation mode
        print(f"Model loaded successfully from {repo_id} to device {device}.")
        return model, config, id2label, device
    except Exception as e:
        print(f"Error loading model weights from {repo_id}/{model_filename}: {e}")
        raise


# Load the model globally when the script starts
try:
    model, config, id2label, device = load_model_from_hf(
        MODEL_REPO_ID, MODEL_FILENAME, CONFIG_FILENAME
    )
except Exception as e:
    print(f"FATAL: Could not load model or config. Gradio app cannot start. Error: {e}")
    # Optionally, exit or raise a specific error for Gradio to catch if possible.
    # Falling back to None lets predict() report the failure instead of crashing.
    model, config, id2label, device = None, None, None, None

# --- 2. Define Preprocessing ---
IMG_SIZE = (224, 224)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
preprocess = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])


# --- 3. Define Prediction Function ---
def _error_result(message):
    """Build a label->confidence dict that carries an error message.

    Every value must be a float for the result to be usable as label
    confidences, so the message is folded into the single key.
    """
    return {f"Error: {message}": 1.0}


def predict(inp_image):
    """Classify an image and return per-label confidences.

    Args:
        inp_image: Input image; it must provide ``convert("RGB")`` (e.g. a
            PIL image). ``None`` is reported as an error result.

    Returns:
        A dict mapping each label in ``id2label`` to its softmax probability.
        If the model is not loaded, no image is given, or inference fails,
        a single-entry dict ``{"Error: <reason>": 1.0}`` is returned instead.
    """
    if model is None or id2label is None:
        return _error_result("Model not loaded")  # Handle model load failure
    if inp_image is None:
        return _error_result("No image provided")

    try:
        # Ensure image is RGB
        img = inp_image.convert("RGB")
        input_tensor = preprocess(img)
        input_batch = input_tensor.unsqueeze(0)  # Add batch dimension
        input_batch = input_batch.to(device)  # Move tensor to the correct device

        with torch.no_grad():
            output = model(input_batch)

        # Convert once to plain Python floats instead of one tensor read per class.
        probabilities = torch.nn.functional.softmax(output[0], dim=0).tolist()

        # Prepare output for Gradio Label component (dictionary {label: probability}).
        # id2label comes from JSON, so its keys are the class indices as strings.
        confidences = {
            id2label[str(i)]: probabilities[i] for i in range(len(id2label))
        }
        return confidences
    except Exception as e:
        print(f"Error during prediction: {e}")
        return _error_result(f"Prediction failed: {e}")


# --- 4. Create Gradio Interface ---
# Add example images (Make sure these paths exist within your Space repo!)
# Create an 'images' folder in your Space and upload some examples.
example_list = [ ["images/example_apple.jpg"], # <-- REPLACE with actual paths in your Space repo ["images/example_banana.jpg"], # <-- REPLACE with actual paths in your Space repo ["images/example_strawberry.jpg"] # <-- REPLACE with actual paths in your Space repo ] # Check if example files exist, otherwise provide empty list if not all(os.path.exists(ex[0]) for ex in example_list): print("Warning: Example image paths not found. Clearing examples.") example_list = [] # Define Title, Description, and Article for the Gradio app title = "Fruit Classifier 🍎🍌🍓" description = """ Upload an image of a fruit or use one of the examples below. This demo uses an EfficientNet-B0 model fine-tuned on the Fruits-360 dataset (with merged classes) to predict the fruit type. Model hosted on Hugging Face Hub: [{MODEL_REPO_ID}](https://huggingface.co/{MODEL_REPO_ID}) """.format(MODEL_REPO_ID=MODEL_REPO_ID) # Format description with repo ID article = """