import gradio as gr
import torch
import time
from PIL import Image, ImageDraw, ImageFont
from transformers import (
    AutoProcessor,
    Owlv2ForObjectDetection,
    Qwen2VLForConditionalGeneration,
)

# Initialize models
obj_processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
obj_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

cbt_model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-2B-Instruct",
    torch_dtype="auto",
    device_map="auto",
)
cbt_processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

# Color palette for bounding boxes, cycled per detected label
colors = [
    (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 165, 0), (75, 0, 130),
    (255, 255, 0), (0, 255, 255), (255, 105, 180), (138, 43, 226), (0, 128, 0),
    (0, 128, 128), (255, 20, 147), (64, 224, 208), (128, 0, 128), (70, 130, 180),
    (220, 20, 60), (255, 140, 0), (34, 139, 34), (218, 112, 214), (255, 99, 71),
    (47, 79, 79), (186, 85, 211), (240, 230, 140), (169, 169, 169), (199, 21, 133)
]

# Shared conversation history passed to the chat model on every turn
history = [
    {
        "role": "system",
        "content": [
            {"type": "image"},
            {
                "type": "text",
                "text": (
                    "You are a conversational image recognition chatbot. "
                    "Communicate with humans using natural language. Recognize the image, "
                    "reason about its spatial layout, and answer questions concisely. "
                    "Generate the best response for each user query; it must be lexically "
                    "and grammatically correct."
                ),
            },
        ],
    }
]


def detect_objects(image, objects):
    """Run zero-shot object detection with OWLv2 for the requested labels."""
    texts = [objects]
    inputs = obj_processor(text=texts, images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = obj_model(**inputs)

    # Map the predictions back to the original image's coordinate space
    target_sizes = torch.Tensor([image.size[::-1]])
    results = obj_processor.post_process_object_detection(
        outputs=outputs, threshold=0.2, target_sizes=target_sizes
    )

    boxes = results[0]["boxes"]
    scores = results[0]["scores"]
    labels = results[0]["labels"]
    return image, boxes, scores, labels


def annotate_image(image, boxes, scores, labels, objects):
    """Draw labeled bounding boxes with confidence scores on the image."""
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    for box, score, label in zip(boxes, scores, labels):
        box = [round(coord, 2) for coord in box.tolist()]
        color = colors[label % len(colors)]
        draw.rectangle(box, outline=color, width=3)
        draw.text((box[0], box[1]), f"{objects[label]}: {score:.2f}", font=font, fill=color)
    return image


def run_object_detection(image, objects):
    """Detect the comma-separated objects and record the result in the chat history."""
    object_list = [obj.strip() for obj in objects.split(",")]
    image, boxes, scores, labels = detect_objects(image, object_list)
    annotated_image = annotate_image(image, boxes, scores, labels, object_list)

    # Tell the chatbot which objects were found so it can refer to them later
    detected_names = [object_list[label] for label in labels.tolist()]
    history.append({
        "role": "system",
        "content": [
            {"type": "text", "text": f"In the image the objects detected are {detected_names}"}
        ]
    })
    return annotated_image


def user(message, chat_history):
    """Append the user's message to the Gradio chat history and clear the textbox."""
    return "", chat_history + [[message, ""]]


def chat_function(image, chat_history):
    """Generate a streamed answer from Qwen2-VL for the latest user message."""
    message = ""
    if chat_history[-1][0] is not None:
        message = str(chat_history[-1][0])

    history.append({
        "role": "user",
        "content": [{"type": "text", "text": message}]
    })

    text_prompt = cbt_processor.apply_chat_template(history, add_generation_prompt=True)
    inputs = cbt_processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt"
    )
    inputs = inputs.to("cuda" if torch.cuda.is_available() else "cpu")

    output_ids = cbt_model.generate(**inputs, max_new_tokens=1024)
    # Strip the prompt tokens so only the newly generated answer is decoded
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs.input_ids, output_ids)
    ]
    bot_output = cbt_processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )

    # batch_decode returns a list with a single string; keep only the text itself
    bot_output_text = bot_output[0]
    history.append({
        "role": "assistant",
        "content": [{"type": "text", "text": bot_output_text}]
    })

    # Stream the reply character by character for a typing effect
    display_text = bot_output_text.replace("\n", " ")
    chat_history[-1][1] = ""
    for character in display_text:
        chat_history[-1][1] += character
        time.sleep(0.05)
        yield chat_history


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Upload an Image")
            image_input = gr.Image(type="pil", label="Upload your image here")
            objects_input = gr.Textbox(
                label="Enter the objects to detect (comma-separated)",
                placeholder="e.g. 'cat, dog, car'"
            )
            image_output = gr.Image(type="pil", label="Detected Objects")
            detect_button = gr.Button("Detect Objects")
            detect_button.click(
                fn=run_object_detection,
                inputs=[image_input, objects_input],
                outputs=image_output
            )

        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            msg = gr.Textbox()
            clear = gr.ClearButton([msg, chatbot])

            msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
                chat_function, [image_input, chatbot], [chatbot]
            )
            clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()