import gradio as gr
import torch
import time
from PIL import Image, ImageDraw, ImageFont
from transformers import (
    AutoProcessor,
    Owlv2ForObjectDetection,
    Qwen2VLForConditionalGeneration,
)
# Initialize models: OWLv2 for zero-shot object detection, Qwen2-VL for image-grounded chat
obj_processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
obj_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")
cbt_model = Qwen2VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2-VL-2B-Instruct",
torch_dtype="auto",
device_map="auto",
)
cbt_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
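# Fixed RGB palette; bounding boxes cycle through it by label index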
colors = [
(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 165, 0), (75, 0, 130),
(255, 255, 0), (0, 255, 255), (255, 105, 180), (138, 43, 226), (0, 128, 0),
(0, 128, 128), (255, 20, 147), (64, 224, 208), (128, 0, 128), (70, 130, 180),
(220, 20, 60), (255, 140, 0), (34, 139, 34), (218, 112, 214), (255, 99, 71),
(47, 79, 79), (186, 85, 211), (240, 230, 140), (169, 169, 169), (199, 21, 133)
]
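# Shared conversation history, seeded with a system prompt; the image placeholder is filled in when the chat template is applied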
history = [
    {
        "role": "system",
        "content": [
            {
                "type": "image",
            },
            {
                "type": "text",
                "text": "You are a conversational image recognition chatbot. Communicate with humans in natural language. Recognize the image, reason about spatial relationships, and answer questions concisely. Generate the best response for a user query; it must be lexically and grammatically correct.",
            }
        ]
    }
]
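# Run zero-shot object detection with OWLv2, using the requested object names as text queries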
def detect_objects(image, objects):
    texts = [objects]  # OWLv2 expects one list of text queries per image
    inputs = obj_processor(text=texts, images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = obj_model(**inputs)
    # Rescale boxes to the original image size (height, width)
    target_sizes = torch.Tensor([image.size[::-1]])
    results = obj_processor.post_process_object_detection(
        outputs=outputs, threshold=0.2, target_sizes=target_sizes
    )
    # Only one image was passed, so read the first (and only) result
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
    return image, boxes, scores, labels
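# Draw labelled bounding boxes with confidence scores onto the image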
def annotate_image(image, boxes, scores, labels, objects):
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    for box, score, label in zip(boxes, scores, labels):
        box = [round(coord, 2) for coord in box.tolist()]
        color = colors[label % len(colors)]
        draw.rectangle(box, outline=color, width=3)
        # Annotate each box with the detected object name and its confidence score
        draw.text((box[0], box[1]), f"{objects[label]}: {score:.2f}", font=font, fill=color)
    return image
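# Gradio callback: detect the requested objects, annotate the image, and record the detections in the chat history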
def run_object_detection(image, objects):
    object_list = [obj.strip() for obj in objects.split(",")]
    image, boxes, scores, labels = detect_objects(image, object_list)
    annotated_image = annotate_image(image, boxes, scores, labels, object_list)
    # Store the detected object names (not the raw label indices) so the chatbot can refer to them
    detected_names = [object_list[label] for label in labels]
    history.append({
        'role': 'system',
        'content': [
            {
                'type': 'text',
                'text': f'In the image the objects detected are {detected_names}'
            }
        ]
    })
    return annotated_image
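# Append the user's message to the visible chat log and clear the textbox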
def user(message, chat_history):
    return "", chat_history + [[message, ""]]
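# Gradio callback: send the latest user message (plus the image) to Qwen2-VL and stream the reply back character by character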
def chat_function(image, chat_history):
    message = ''
    if chat_history[-1][0] is not None:
        message = str(chat_history[-1][0])
    history.append({
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": message
            }
        ]
    })
    # Build the prompt from the full conversation history and the uploaded image
    text_prompt = cbt_processor.apply_chat_template(history, add_generation_prompt=True)
    inputs = cbt_processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt"
    )
    inputs = inputs.to("cuda" if torch.cuda.is_available() else "cpu")
    output_ids = cbt_model.generate(**inputs, max_new_tokens=1024)
    # Strip the prompt tokens so only the newly generated tokens are decoded
    generated_ids = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, output_ids)
    ]
    bot_output = cbt_processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]
    history.append({
        "role": "assistant",
        "content": [
            {
                "type": "text",
                "text": bot_output
            }
        ]
    })
    # Convert newlines for the Chatbot display
    bot_output_str = bot_output.replace("\n", "<br>")
    # Stream the reply one character at a time to simulate typing
    chat_history[-1][1] = ""
    for character in bot_output_str:
        chat_history[-1][1] += character
        time.sleep(0.05)
        yield chat_history
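# Build the Gradio interface: image upload and object detection on the left, chatbot on the right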
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Upload an Image")
            image_input = gr.Image(type="pil", label="Upload your image here")
            objects_input = gr.Textbox(label="Enter the objects to detect (comma-separated)", placeholder="e.g. 'cat, dog, car'")
            image_output = gr.Image(type="pil", label="Detected Objects")
            detect_button = gr.Button("Detect Objects")
            detect_button.click(fn=run_object_detection, inputs=[image_input, objects_input], outputs=image_output)
        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            msg = gr.Textbox()
            clear = gr.ClearButton([msg, chatbot])
            msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
                chat_function, [image_input, chatbot], [chatbot]
            )
            clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()