Abhinav2809 committed on
Commit 6d4e500 · verified · 1 Parent(s): d353640

Create app.py

Files changed (1)
  1. app.py +170 -0
app.py ADDED
@@ -0,0 +1,170 @@
import gradio as gr
import torch
import time
from PIL import ImageDraw, ImageFont
from transformers import (
    AutoProcessor,
    Owlv2ForObjectDetection,
    Qwen2VLForConditionalGeneration,
)

# Initialize models
obj_processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
obj_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

cbt_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype="auto",
    device_map="auto",
)
cbt_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

# Color palette for bounding-box annotations, cycled by label index.
colors = [
    (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 165, 0), (75, 0, 130),
    (255, 255, 0), (0, 255, 255), (255, 105, 180), (138, 43, 226), (0, 128, 0),
    (0, 128, 128), (255, 20, 147), (64, 224, 208), (128, 0, 128), (70, 130, 180),
    (220, 20, 60), (255, 140, 0), (34, 139, 34), (218, 112, 214), (255, 99, 71),
    (47, 79, 79), (186, 85, 211), (240, 230, 140), (169, 169, 169), (199, 21, 133)
]

# Shared chat history, seeded with the system prompt for the chatbot.
history = [
    {
        "role": "system",
        "content": [
            {
                "type": "image",
            },
            {
                "type": "text",
                "text": "You are a conversational image recognition chatbot. Communicate with humans using natural language. Recognize images, reason about their spatial layout, and answer questions concisely. Generate the best response for each user query; it must be lexically and grammatically correct.",
            },
        ],
    }
]

def detect_objects(image, objects):
    # OWLv2 expects one list of candidate labels per image.
    texts = [objects]
    inputs = obj_processor(text=texts, images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = obj_model(**inputs)

    # target_sizes is (height, width), so reverse PIL's (width, height).
    target_sizes = torch.Tensor([image.size[::-1]])
    results = obj_processor.post_process_object_detection(
        outputs=outputs, threshold=0.2, target_sizes=target_sizes
    )

    # A single image was passed, so take the first result.
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
    return image, boxes, scores, labels

def annotate_image(image, boxes, scores, labels, objects):
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()

    for box, score, label in zip(boxes, scores, labels):
        box = [round(coord, 2) for coord in box.tolist()]
        label_idx = int(label)  # tensor -> int so it can index Python lists
        color = colors[label_idx % len(colors)]
        draw.rectangle(box, outline=color, width=3)
        draw.text((box[0], box[1]), f"{objects[label_idx]}: {score:.2f}", font=font, fill=color)

    return image

def run_object_detection(image, objects):
    object_list = [obj.strip() for obj in objects.split(",")]
    image, boxes, scores, labels = detect_objects(image, object_list)
    annotated_image = annotate_image(image, boxes, scores, labels, object_list)
    # Record the detections for the chatbot, using object names
    # rather than raw label-index tensors.
    detected = [object_list[int(label)] for label in labels]
    history.append({
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": f"In the image the objects detected are {detected}",
            }
        ],
    })
    return annotated_image

def user(message, chat_history):
    # Clear the textbox and append the user turn with an empty bot slot.
    return "", chat_history + [[message, ""]]

def chat_function(image, chat_history):
    message = ""
    if chat_history[-1][0] is not None:
        message = str(chat_history[-1][0])

    history.append({
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": message,
            }
        ],
    })

    text_prompt = cbt_processor.apply_chat_template(history, add_generation_prompt=True)

    inputs = cbt_processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt",
    )

    inputs = inputs.to("cuda" if torch.cuda.is_available() else "cpu")

    output_ids = cbt_model.generate(**inputs, max_new_tokens=1024)

    # Strip the prompt tokens so only the newly generated reply is decoded.
    generated_ids = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, output_ids)
    ]

    # batch_decode returns one string per input; take the first.
    bot_output = cbt_processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]

    history.append({
        "role": "assistant",
        "content": [
            {
                "type": "text",
                "text": bot_output,
            }
        ],
    })

    bot_output_str = bot_output.replace("\n", "<br>")

    # Stream the reply character by character for a typing effect.
    chat_history[-1][1] = ""
    for character in bot_output_str:
        chat_history[-1][1] += character
        time.sleep(0.05)
        yield chat_history

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Upload an Image")
            image_input = gr.Image(type="pil", label="Upload your image here")
            objects_input = gr.Textbox(
                label="Enter the objects to detect (comma-separated)",
                placeholder="e.g. 'cat, dog, car'",
            )
            image_output = gr.Image(type="pil", label="Detected Objects")
            detect_button = gr.Button("Detect Objects")
            detect_button.click(
                fn=run_object_detection,
                inputs=[image_input, objects_input],
                outputs=image_output,
            )

        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            msg = gr.Textbox()
            clear = gr.ClearButton([msg, chatbot])

            msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
                chat_function, [image_input, chatbot], [chatbot]
            )
            clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()
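
For testing the detection path locally before launching the UI, a minimal smoke-test sketch follows; the image path "sample.jpg" and the label list are placeholder assumptions, not part of this commit:

# Hypothetical local smoke test for detect_objects(); run after the models
# load, in place of demo.launch(). The file path and labels are assumptions.
from PIL import Image

img = Image.open("sample.jpg").convert("RGB")
_, boxes, scores, labels = detect_objects(img, ["cat", "dog"])
for box, score, label in zip(boxes, scores, labels):
    print([round(c, 1) for c in box.tolist()], f"{score:.2f}", int(label))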