import json
import urllib.request

import gradio as gr
import requests
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms
from transformers import AutoModel, AutoTokenizer


# --- Define the Model ---
class FineGrainedClassifier(nn.Module):
    """Multimodal fine-grained classifier.

    Fuses ResNet-50 image features (2048-d, final FC stripped) with the [CLS]
    embedding of a Jina v2 text encoder (768-d), then classifies the
    concatenated vector through a small MLP head.
    """

    def __init__(self, num_classes=434):  # 434 ecommerce categories by default
        super(FineGrainedClassifier, self).__init__()
        # ResNet-50 backbone; replacing `fc` with Identity exposes the raw
        # 2048-d pooled features instead of ImageNet logits.
        self.image_encoder = models.resnet50(pretrained=True)
        self.image_encoder.fc = nn.Identity()
        self.text_encoder = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en')
        self.classifier = nn.Sequential(
            nn.Linear(2048 + 768, 1024),  # image (2048) + text (768) features
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, image, input_ids, attention_mask):
        """Return raw class logits for a batch of (image, tokenized-title) pairs."""
        image_features = self.image_encoder(image)
        text_output = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        # First token ([CLS]) embedding as the sentence-level representation.
        text_features = text_output.last_hidden_state[:, 0, :]
        combined_features = torch.cat((image_features, text_features), dim=1)
        return self.classifier(combined_features)


# --- Inference Preprocessing ---
# NOTE(fix): the original pipeline applied training-time augmentation
# (RandomHorizontalFlip, RandomRotation, ColorJitter) at inference time,
# which made predictions nondeterministic. Inference uses a deterministic
# Resize -> ToTensor -> Normalize pipeline (ImageNet statistics).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Load the label-to-class mapping from the Hugging Face repository.
# Keys are stringified class indices, values are human-readable class names.
label_map_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/label_to_class.json"
label_to_class = requests.get(label_map_url).json()

# Build the model sized to the label map and pull the fine-tuned weights.
model = FineGrainedClassifier(num_classes=len(label_to_class))
checkpoint_url = "https://huggingface.co/Maverick98/EcommerceClassifier/resolve/main/model_checkpoint.pth"
checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device('cpu'))

# Clean up the state dictionary:
#  - unwrap a training-time {'model_state_dict': ...} container if present;
#  - strip the "module." prefix added by nn.DataParallel;
#  - keep only keys the current architecture actually has.
state_dict = checkpoint.get('model_state_dict', checkpoint)
new_state_dict = {}
for key, value in state_dict.items():
    new_key = key[len("module."):] if key.startswith("module.") else key
    if new_key in model.state_dict():
        new_state_dict[new_key] = value
model.load_state_dict(new_state_dict)
model.eval()  # inference-only app: disable dropout/batchnorm updates up front

# Tokenizer matching the text encoder.
tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en")


def load_image(image_path_or_url):
    """Load an image from a local path/file or an HTTP(S) URL.

    Returns a float tensor of shape (1, 3, 224, 224), preprocessed for the model.
    """
    if isinstance(image_path_or_url, str) and image_path_or_url.startswith("http"):
        with urllib.request.urlopen(image_path_or_url) as response:
            image = Image.open(response).convert('RGB')
    else:
        image = Image.open(image_path_or_url).convert('RGB')
    return transform(image).unsqueeze(0)  # add batch dimension


def predict(image_path_or_file, title, threshold=0.4):
    """Classify a product from its image and title.

    Args:
        image_path_or_file: local path, file object, or HTTP(S) URL of the image.
        title: product title; must contain at least 3 words.
        threshold: if the top-1 probability falls below this value, an
            "Others" pseudo-class is prepended to the results.

    Returns:
        dict mapping the top class names to their probabilities.

    Raises:
        gr.Error: if the title is empty or shorter than 3 words.
    """
    # Validation: reject empty or too-short titles before any heavy work.
    if not title or len(title.split()) < 3:
        raise gr.Error("Title must be at least 3 words long. Please provide a valid title.")

    # Preprocess the image.
    image = load_image(image_path_or_file)

    # Tokenize the title to a fixed length of 200 tokens.
    title_encoding = tokenizer(title, padding='max_length', max_length=200,
                               truncation=True, return_tensors='pt')
    input_ids = title_encoding['input_ids']
    attention_mask = title_encoding['attention_mask']

    # Predict.
    model.eval()
    with torch.no_grad():
        output = model(image, input_ids=input_ids, attention_mask=attention_mask)
        probabilities = torch.nn.functional.softmax(output, dim=1)
        top3_probabilities, top3_indices = torch.topk(probabilities, 3, dim=1)

    # Map the top 3 indices to class names (label map keys are strings).
    top3_classes = [label_to_class[str(idx.item())] for idx in top3_indices[0]]

    # Low-confidence fallback: prepend "Others" with the residual probability.
    if top3_probabilities[0][0].item() < threshold:
        top3_classes.insert(0, "Others")
        top3_probabilities = torch.cat(
            (torch.tensor([[1.0 - top3_probabilities[0][0].item()]]), top3_probabilities),
            dim=1,
        )

    # Prepare the output as a {class_name: probability} dictionary.
    results = {}
    for i in range(len(top3_classes)):
        results[top3_classes[i]] = top3_probabilities[0][i].item()
    return results


# --- Gradio interface ---
title_input = gr.Textbox(label="Product Title", placeholder="Enter the product title here...")
image_input = gr.Image(type="filepath", label="Upload Image or Provide URL")
output = gr.JSON(label="Top 3 Predictions with Probabilities")

gr.Interface(
    fn=predict,
    inputs=[image_input, title_input],
    outputs=output,
    title="Ecommerce Classifier",
    description="This model classifies ecommerce products into one of 434 categories. If the model is unsure, it outputs 'Others'.",
).launch(share=True)