import torch
from transformers import AutoTokenizer, VisionEncoderDecoderModel, AutoImageProcessor
from PIL import Image
from torchvision.transforms.functional import crop
import gradio as gr
import base64
import io
from huggingface_hub import hf_hub_download
import zipfile
import os
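# Gradio app: a custom YOLOv5 model detects objects in an uploaded image, then a
# ViT-GPT2 captioning model describes the whole image and each detected object.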
# Global variables for models
object_detection_model = None
captioning_model = None
tokenizer = None
captioning_processor = None
# Load models during initialization
def init():
    global object_detection_model, captioning_model, tokenizer, captioning_processor

    # Step 1: Load the YOLOv5 model from Hugging Face
    try:
        print("Loading YOLOv5 model...")
        # Get the Hugging Face auth token from an environment variable
        auth_token = os.getenv("HF_AUTH_TOKEN")
        if not auth_token:
            print("Error: HF_AUTH_TOKEN environment variable not set.")
            object_detection_model = None
        else:
            # Download the zip file from Hugging Face
            zip_path = hf_hub_download(
                repo_id='Mexbow/Yolov5_object_detection',
                filename='yolov5.zip',
                token=auth_token,
            )
            # Extract the YOLOv5 model
            extract_path = './yolov5_model'  # Extraction path
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                os.makedirs(extract_path, exist_ok=True)
                zip_ref.extractall(extract_path)
            # Load the custom YOLOv5 weights
            model_path = os.path.join(extract_path, 'yolov5/weights/best14.pt')
            if not os.path.exists(model_path):
                print(f"Error: YOLOv5 model file not found at {model_path}")
                object_detection_model = None
            else:
                object_detection_model = torch.hub.load('ultralytics/yolov5', 'custom', path=model_path, trust_repo=True)
                print("YOLOv5 model loaded successfully.")
    except Exception as e:
        print(f"Error loading YOLOv5 model: {e}")
        object_detection_model = None

    # Step 2: Load the ViT-GPT2 captioning model from Hugging Face
    try:
        print("Loading ViT-GPT2 model...")
        captioning_model = VisionEncoderDecoderModel.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        tokenizer = AutoTokenizer.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        captioning_processor = AutoImageProcessor.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        print("ViT-GPT2 model loaded successfully.")
    except Exception as e:
        print(f"Error loading captioning model: {e}")
        captioning_model, tokenizer, captioning_processor = None, None, None
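# Note: HF_AUTH_TOKEN is read from the environment; on Hugging Face Spaces it would
# typically be configured as a repository secret, and locally it can be exported
# before launching, e.g. `export HF_AUTH_TOKEN=hf_xxx` (placeholder value).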
# Utility function to crop objects from the image based on bounding boxes
def crop_objects(image, boxes):
    cropped_images = []
    for box in boxes:
        left, top, right, bottom = box
        cropped_image = image.crop((left, top, right, bottom))
        cropped_images.append(cropped_image)
    return cropped_images
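# Example (sketch): YOLOv5 returns float pixel coordinates, so a box such as
#   (34.5, 12.0, 200.7, 180.2)
# is passed straight to PIL's Image.crop(), which takes a (left, top, right, bottom) box.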
# Gradio interface function
def process_image(image):
    global object_detection_model, captioning_model, tokenizer, captioning_processor

    # Ensure models are loaded
    if object_detection_model is None or captioning_model is None or tokenizer is None or captioning_processor is None:
        return None, "Error: models are not loaded properly", None
    try:
        # Step 1: Perform object detection with YOLOv5
        results = object_detection_model(image)
        boxes = results.xyxy[0][:, :4].cpu().numpy()  # Bounding boxes (x1, y1, x2, y2)
        labels = [results.names[int(class_id)] for class_id in results.xyxy[0][:, 5].cpu().numpy().astype(int)]  # Class names
        scores = results.xyxy[0][:, 4].cpu().numpy()  # Confidence scores

        # Step 2: Generate a caption for the whole image
        original_inputs = captioning_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            original_caption_ids = captioning_model.generate(**original_inputs)
        original_caption = tokenizer.decode(original_caption_ids[0], skip_special_tokens=True)

        # Step 3: Crop detected objects and generate a caption for each object
        cropped_images = crop_objects(image, boxes)
        captions = []
        for cropped_image in cropped_images:
            inputs = captioning_processor(images=cropped_image, return_tensors="pt")
            with torch.no_grad():
                caption_ids = captioning_model.generate(**inputs)
            caption = tokenizer.decode(caption_ids[0], skip_special_tokens=True)
            captions.append(caption)

        # Prepare the per-object results as a formatted string
        detection_results = ""
        for i, (label, box, score, caption) in enumerate(zip(labels, boxes, scores, captions)):
            detection_results += (
                f"Object {i + 1}: {label} ({score:.2f}) at "
                f"[{box[0]:.0f}, {box[1]:.0f}, {box[2]:.0f}, {box[3]:.0f}] - Caption: {caption}\n"
            )

        # Render the image with bounding boxes drawn on it
        result_image = results.render()[0]

        # Return the annotated image, the formatted per-object captions, and the whole-image caption
        return result_image, detection_results, original_caption
    except Exception as e:
        return None, f"Error: {e}", None
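# Quick local check (sketch; assumes an image file named 'sample.jpg' next to this script):
#   img = Image.open("sample.jpg").convert("RGB")
#   annotated, per_object_captions, whole_caption = process_image(img)
#   print(whole_caption)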
# Initialize models
init()
# Gradio Interface
interface = gr.Interface(
    fn=process_image,             # Function to run
    inputs=gr.Image(type="pil"),  # Input: image upload
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),                  # Output 1: image with bounding boxes
        gr.Textbox(label="Object Captions & Bounding Boxes", lines=10),  # Output 2: formatted per-object captions
        gr.Textbox(label="Whole Image Caption"),                         # Output 3: caption for the whole image
    ],
    live=True
)
# Launch the Gradio app
interface.launch()
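# When running outside Spaces, a temporary public URL can be requested with
# interface.launch(share=True); the default launch() is sufficient on Spaces.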