import torch
from transformers import AutoTokenizer, VisionEncoderDecoderModel, AutoImageProcessor
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import zipfile
import os

# Global variables for models
object_detection_model = None
captioning_model = None
tokenizer = None
captioning_processor = None

# Load models during initialization
def init():
    global object_detection_model, captioning_model, tokenizer, captioning_processor

    # Step 1: Load the YOLOv5 model from Hugging Face
    try:
        print("Loading YOLOv5 model...")
        # Get Hugging Face auth token from environment variable
        auth_token = os.getenv("HF_AUTH_TOKEN")
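        # The token can be supplied before launch, e.g. `export HF_AUTH_TOKEN=hf_...`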
        if not auth_token:
            print("Error: HF_AUTH_TOKEN environment variable not set.")
            object_detection_model = None
        else:
            # Download the zip file from Hugging Face
            zip_path = hf_hub_download(repo_id='Mexbow/Yolov5_object_detection', filename='yolov5.zip', token=auth_token)
            
            # Extract the YOLOv5 model
            extract_path = './yolov5_model'  # Specify extraction path
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                os.makedirs(extract_path, exist_ok=True)
                zip_ref.extractall(extract_path)
            
            # Load the YOLOv5 model
            model_path = os.path.join(extract_path, 'yolov5/weights/best14.pt')
            if not os.path.exists(model_path):
                print(f"Error: YOLOv5 model file not found at {model_path}")
                object_detection_model = None
            else:
                object_detection_model = torch.hub.load('ultralytics/yolov5', 'custom', path=model_path, trust_repo=True)
                print("YOLOv5 model loaded successfully.")
    except Exception as e:
        print(f"Error loading YOLOv5 model: {e}")
        object_detection_model = None

    # Step 2: Load the ViT-GPT2 captioning model from Hugging Face
    try:
        print("Loading ViT-GPT2 model...")
        captioning_model = VisionEncoderDecoderModel.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        tokenizer = AutoTokenizer.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        captioning_processor = AutoImageProcessor.from_pretrained("motheecreator/ViT-GPT2-Image-Captioning")
        print("ViT-GPT2 model loaded successfully.")
    except Exception as e:
        print(f"Error loading captioning model: {e}")
        captioning_model, tokenizer, captioning_processor = None, None, None

# Utility function to crop objects from the image based on bounding boxes
def crop_objects(image, boxes):
    cropped_images = []
    for box in boxes:
        left, top, right, bottom = box
        cropped_image = image.crop((left, top, right, bottom))
        cropped_images.append(cropped_image)
    return cropped_images
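# Example: crop_objects(img, [(10, 20, 110, 220)]) returns a single 100x200 crop
# (PIL crop boxes are given as left, top, right, bottom).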

# Gradio interface function
def process_image(image):
    global object_detection_model, captioning_model, tokenizer, captioning_processor

    # Ensure models are loaded
    if object_detection_model is None or captioning_model is None or tokenizer is None or captioning_processor is None:
        return None, "Error: models are not loaded properly", None

    try:
        # Step 1: Perform object detection with YOLOv5
        results = object_detection_model(image)
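        # results.xyxy[0] is an (N, 6) tensor: [x1, y1, x2, y2, confidence, class_id] per detection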
        boxes = results.xyxy[0][:, :4].cpu().numpy()  # Bounding boxes
        labels = [results.names[int(class_id)] for class_id in results.xyxy[0][:, 5].cpu().numpy().astype(int)]  # Class names
        scores = results.xyxy[0][:, 4].cpu().numpy()  # Confidence scores

        # Step 2: Generate caption for the whole image
        original_inputs = captioning_processor(images=image, return_tensors="pt")
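        # generate() runs the ViT encoder and GPT-2 decoder with the model's default generation settings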
        with torch.no_grad():
            original_caption_ids = captioning_model.generate(**original_inputs)
        original_caption = tokenizer.decode(original_caption_ids[0], skip_special_tokens=True)

        # Step 3: Crop detected objects and generate captions for each object
        cropped_images = crop_objects(image, boxes)
        captions = []
        for cropped_image in cropped_images:
            inputs = captioning_processor(images=cropped_image, return_tensors="pt")
            with torch.no_grad():
                caption_ids = captioning_model.generate(**inputs)
            caption = tokenizer.decode(caption_ids[0], skip_special_tokens=True)
            captions.append(caption)

        # Prepare the per-object results as a formatted string (label, confidence, box, caption)
        detection_results = ""
        for i, (label, box, score, caption) in enumerate(zip(labels, boxes, scores, captions)):
            left, top, right, bottom = box
            detection_results += (
                f"Object {i + 1}: {label} (confidence {score:.2f}), "
                f"box [{left:.0f}, {top:.0f}, {right:.0f}, {bottom:.0f}] - Caption: {caption}\n"
            )

        # Render the image with bounding boxes drawn and convert it back to a PIL image
        result_image = Image.fromarray(results.render()[0])

        # Return the image with detections, formatted captions, and the whole image caption
        return result_image, detection_results, original_caption

    except Exception as e:
        return None, {"error": str(e)}, None

# Initialize models
init()

# Gradio Interface
interface = gr.Interface(
    fn=process_image,  # Function to run
    inputs=gr.Image(type="pil"),  # Input: Image upload
    outputs=[
        gr.Image(type="pil", label="Detected Objects"),  # Output 1: Image with bounding boxes
        gr.Textbox(label="Object Captions & Bounding Boxes", lines=10),  # Output 2: Formatted captions
        gr.Textbox(label="Whole Image Caption")  # Output 3: Caption for the whole image
    ],
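    # live=True re-runs the pipeline on every input change; set to False for an explicit Submit button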
    live=True
)

# Launch the Gradio app
interface.launch()