# Abhinav2809's picture
# Rename gradio.app.py to app.py
# de042ce verified
import gradio as gr
import torch
import time
from PIL import Image, ImageDraw, ImageFont
from transformers import (
AutoProcessor,
Owlv2ForObjectDetection,
Qwen2VLForConditionalGeneration,
AutoTokenizer,
AutoProcessor
)
# Checkpoint identifiers for the two pretrained models used by this app.
_DETECTOR_CHECKPOINT = "google/owlv2-base-patch16-ensemble"
_CHAT_CHECKPOINT = "Qwen/Qwen2-VL-2B-Instruct"

# Object detector and its processor (used by detect_objects with text queries).
obj_processor = AutoProcessor.from_pretrained(_DETECTOR_CHECKPOINT)
obj_model = Owlv2ForObjectDetection.from_pretrained(_DETECTOR_CHECKPOINT)

# Vision-language chat model; dtype and device placement are left to
# transformers via the "auto" settings.
cbt_model = Qwen2VLForConditionalGeneration.from_pretrained(
    _CHAT_CHECKPOINT, torch_dtype="auto", device_map="auto"
)
cbt_processor = AutoProcessor.from_pretrained(_CHAT_CHECKPOINT)
# RGB palette for bounding boxes; annotate_image picks an entry by
# ``label % len(colors)`` so the palette wraps around for many classes.
colors = [
    (255, 0, 0),      # red
    (0, 255, 0),      # lime
    (0, 0, 255),      # blue
    (255, 165, 0),    # orange
    (75, 0, 130),     # indigo
    (255, 255, 0),    # yellow
    (0, 255, 255),    # cyan
    (255, 105, 180),  # hot pink
    (138, 43, 226),   # blue violet
    (0, 128, 0),      # green
    (0, 128, 128),    # teal
    (255, 20, 147),   # deep pink
    (64, 224, 208),   # turquoise
    (128, 0, 128),    # purple
    (70, 130, 180),   # steel blue
    (220, 20, 60),    # crimson
    (255, 140, 0),    # dark orange
    (34, 139, 34),    # forest green
    (218, 112, 214),  # orchid
    (255, 99, 71),    # tomato
    (47, 79, 79),     # dark slate gray
    (186, 85, 211),   # medium orchid
    (240, 230, 140),  # khaki
    (169, 169, 169),  # dark gray
    (199, 21, 133),   # medium violet red
]
# Instruction text for the opening system turn (kept verbatim).
_SYSTEM_PROMPT = (
    "You are an conversation image recognition chatbot. "
    "Communicate with humans using natural language. "
    "Recognize the images, have a spatial understanding and answer the questions in a concise manner. "
    "Generate the best response for a user query. "
    "It must be correct lexically and grammatically."
)

# Module-level conversation log fed to the chat template. It starts with a
# single system turn holding an image placeholder followed by the
# instructions; other functions in this file append further turns to it.
history = [
    {
        "role": "system",
        "content": [
            {"type": "image"},
            {"type": "text", "text": _SYSTEM_PROMPT},
        ],
    },
]
def detect_objects(image, objects, threshold=0.2):
    """Run the object detector on one image for a list of text queries.

    Args:
        image: PIL image to search; its ``size`` supplies the target size
            handed to post-processing.
        objects: List of query strings, one per object category.
        threshold: Minimum score passed to the processor's post-processing
            step. Defaults to 0.2.

    Returns:
        Tuple ``(image, boxes, scores, labels)``: the input image unchanged,
        followed by the ``"boxes"``, ``"scores"`` and ``"labels"`` entries of
        the single post-processed result.
    """
    # The processor takes one list of queries per image; this is a batch of one.
    texts = [objects]
    inputs = obj_processor(text=texts, images=image, return_tensors="pt")

    with torch.no_grad():
        outputs = obj_model(**inputs)

    # PIL reports (width, height); reversed here to (height, width).
    target_sizes = torch.Tensor([image.size[::-1]])
    results = obj_processor.post_process_object_detection(
        outputs=outputs, threshold=threshold, target_sizes=target_sizes
    )

    # Only one image was submitted, so only the first result is relevant.
    detections = results[0]
    return image, detections["boxes"], detections["scores"], detections["labels"]
def annotate_image(image, boxes, scores, labels, objects):
    """Draw a labelled bounding box for every detection and return the result.

    Drawing happens on a copy so the caller's image is not modified.

    Args:
        image: PIL image the detections refer to.
        boxes: Iterable of boxes; each must support ``.tolist()`` and yield
            the four rectangle coordinates passed to ``ImageDraw.rectangle``.
        scores: Iterable of confidence scores, parallel to ``boxes``.
        labels: Iterable of integer indices into ``objects``, parallel to
            ``boxes``.
        objects: List of query strings the label indices refer to.

    Returns:
        A new PIL image with one outlined rectangle and one
        ``"<name>: <score>"`` caption per detection.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    font = ImageFont.load_default()

    for box, score, label in zip(boxes, scores, labels):
        corners = [round(coord, 2) for coord in box.tolist()]
        # Convert to a plain int once so list indexing and the modulo below
        # do not depend on the label's container type.
        class_index = int(label)
        color = colors[class_index % len(colors)]
        draw.rectangle(corners, outline=color, width=3)
        draw.text(
            (corners[0], corners[1]),
            f"{objects[class_index]}: {float(score):.2f}",
            font=font,
            fill=color,
        )
    return annotated
def run_object_detection(image, objects):
    """Detect the comma-separated objects in ``image`` and return it annotated.

    Besides producing the annotated image, this appends a system turn to the
    module-level ``history`` naming the objects that were found, so the chat
    model can refer to them.

    Args:
        image: PIL image from the upload widget, or ``None`` if nothing was
            uploaded.
        objects: Comma-separated object names typed by the user.

    Returns:
        The annotated image, or ``image`` unchanged when there is no image or
        no non-blank object name to look for.
    """
    # Drop blank entries such as those produced by a trailing comma.
    object_list = [
        name for name in (part.strip() for part in (objects or "").split(",")) if name
    ]
    if image is None or not object_list:
        return image

    image, boxes, scores, labels = detect_objects(image, object_list)
    annotated_image = annotate_image(image, boxes, scores, labels, object_list)

    # Labels are indices into object_list; report the names (each once, in
    # first-seen order) rather than the raw index tensor.
    detected_names = list(dict.fromkeys(object_list[int(label)] for label in labels))
    summary = ", ".join(detected_names) if detected_names else "none"
    history.append({
        'role': 'system',
        'content': [
            {
                'type': 'text',
                'text': f'In the image the objects detected are {summary}'
            }
        ]
    })
    return annotated_image
def user(message, chat_history):
    """Queue the user's message as a new chat turn with an empty reply.

    Returns an empty string (to clear the input box) and a new list made of
    ``chat_history`` plus a ``[message, ""]`` pair; the list passed in is not
    modified.
    """
    pending_turn = [message, ""]
    return "", chat_history + [pending_turn]
def chat_function(image, chat_history):
    """Generate the assistant's reply to the latest user turn and stream it.

    The user message and the generated reply are both appended to the
    module-level ``history`` that is rendered through the chat template.
    NOTE: ``history`` is a single module-level list, so it is shared by every
    caller of this function.

    Args:
        image: PIL image from the upload widget, or ``None`` if nothing was
            uploaded.
        chat_history: List of ``[user_message, bot_message]`` pairs; the last
            pair is the turn being answered and is updated in place.

    Yields:
        ``chat_history`` after each character is added to the reply, so the UI
        shows the answer being typed out.
    """
    if not chat_history:
        # Nothing to answer; hand the empty log back unchanged.
        yield chat_history
        return

    if image is None:
        # history always contains an image placeholder, so generating without
        # an image cannot work; tell the user instead of raising.
        chat_history[-1][1] = "Please upload an image first."
        yield chat_history
        return

    latest = chat_history[-1][0]
    message = "" if latest is None else str(latest)
    history.append({
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": message
            }
        ]
    })

    text_prompt = cbt_processor.apply_chat_template(history, add_generation_prompt=True)
    inputs = cbt_processor(
        text=[text_prompt],
        images=[image],
        padding=True,
        return_tensors="pt"
    )
    # The model is loaded with device_map="auto", so follow the model's own
    # device rather than guessing from CUDA availability.
    inputs = inputs.to(cbt_model.device)
    output_ids = cbt_model.generate(**inputs, max_new_tokens=1024)

    # Slice off the prompt tokens so only the newly generated tail is decoded.
    generated_ids = [
        full_ids[len(prompt_ids):]
        for prompt_ids, full_ids in zip(inputs.input_ids, output_ids)
    ]
    decoded = cbt_processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    # batch_decode returns one string per sequence; a single prompt was sent,
    # so use the string itself instead of the repr of the whole list.
    reply = decoded[0] if decoded else ""

    history.append({
        "role": "assistant",
        "content": [
            {
                "type": "text",
                "text": reply
            }
        ]
    })

    bot_output_str = reply.replace("\n", "<br>")
    chat_history[-1][1] = ""
    for character in bot_output_str:
        chat_history[-1][1] += character
        time.sleep(0.05)  # pacing for the typing effect
        yield chat_history
def _forget_chat_turns():
    """Drop user/assistant turns from the shared history; keep system turns.

    Returns ``None`` so the click handler also blanks the chat pane.
    """
    history[:] = [turn for turn in history if turn["role"] == "system"]
    return None


# Two-column layout: image upload and detection on the left, chat on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Upload an Image")
            image_input = gr.Image(type="pil", label="Upload your image here")
            objects_input = gr.Textbox(
                label="Enter the objects to detect (comma-separated)",
                placeholder="e.g. 'cat, dog, car'",
            )
            image_output = gr.Image(type="pil", label="Detected Objects")
            detect_button = gr.Button("Detect Objects")
            detect_button.click(
                fn=run_object_detection,
                inputs=[image_input, objects_input],
                outputs=image_output,
            )
        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            msg = gr.Textbox()
            clear = gr.ClearButton([msg, chatbot])

    # First record the user's turn immediately, then stream the model's reply.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        chat_function, [image_input, chatbot], [chatbot]
    )
    # Clearing the pane must also clear the turns the model is prompted with;
    # otherwise the next reply is still conditioned on the cleared chat.
    clear.click(_forget_chat_turns, None, chatbot, queue=False)

demo.launch()