# Hugging Face file-view residue captured with this file (not part of the program):
#   prithivMLmods's picture | initial commit | 03c9962 verified | raw | history blame | 17.6 kB
import os
import random
import uuid
import json
import time
import asyncio
from threading import Thread
import base64
from io import BytesIO
import re
import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image, ImageDraw
import cv2
from transformers import (
Qwen2VLForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
AutoProcessor,
TextIteratorStreamer,
)
from qwen_vl_utils import process_vision_info
# Limits for text generation; the input-length cap can be overridden through
# the MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Use the first CUDA device when torch reports one, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _load_vlm(model_id):
    """Return ``(processor, model)`` for the Hub checkpoint *model_id*.

    The model is loaded as ``Qwen2_5_VLForConditionalGeneration`` in float16,
    moved to ``device`` and put in eval mode. Both loads pass
    ``trust_remote_code=True``.
    """
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.float16,
    )
    return processor, model.to(device).eval()


# Camel-Doc-OCR-062825
MODEL_ID_M = "prithivMLmods/Camel-Doc-OCR-062825"
processor_m, model_m = _load_vlm(MODEL_ID_M)

# ViLaSR-7B
MODEL_ID_X = "AntResearchNLP/ViLaSR"
processor_x, model_x = _load_vlm(MODEL_ID_X)

# OCRFlux-3B
MODEL_ID_T = "ChatDOC/OCRFlux-3B"
processor_t, model_t = _load_vlm(MODEL_ID_T)

# ShotVL-7B
MODEL_ID_S = "Vchitect/ShotVL-7B"
processor_s, model_s = _load_vlm(MODEL_ID_S)
# Helper functions for object detection
def image_to_base64(image):
    """Encode *image* as PNG and return the PNG bytes as a base64 text string.

    *image* only needs a PIL-style ``save(fp, format=...)`` method.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
def draw_bounding_boxes(image, bounding_boxes, outline_color="red", line_width=2):
    """Draw rectangle outlines for *bounding_boxes* on *image* and return it.

    The rectangles are drawn directly onto *image* (it is modified in place);
    pass a copy if the original must stay untouched.

    Args:
        image: PIL image to draw on.
        bounding_boxes: Iterable of ``[xmin, ymin, xmax, ymax]`` boxes in the
            pixel coordinates of *image*.
        outline_color: Outline colour handed to ``ImageDraw.rectangle``.
        line_width: Outline width in pixels.

    Returns:
        The same *image* object, now carrying the outlines.
    """
    draw = ImageDraw.Draw(image)
    for xmin, ymin, xmax, ymax in bounding_boxes:
        # Boxes come from parsed model output and are not guaranteed to have
        # their corners in order; Pillow raises ValueError for x1 < x0 or
        # y1 < y0, so normalize the corners before drawing.
        left, right = sorted((xmin, xmax))
        top, bottom = sorted((ymin, ymax))
        draw.rectangle([left, top, right, bottom], outline=outline_color, width=line_width)
    return image
def rescale_bounding_boxes(bounding_boxes, original_width, original_height, scaled_width=1000, scaled_height=1000):
    """Map boxes from a ``scaled_width`` x ``scaled_height`` grid to image pixels.

    Each ``[xmin, ymin, xmax, ymax]`` box has its x values multiplied by
    ``original_width / scaled_width`` and its y values multiplied by
    ``original_height / scaled_height``. A new list of four-element lists is
    returned; the input is left unchanged.
    """
    x_factor = original_width / scaled_width
    y_factor = original_height / scaled_height
    return [
        [xmin * x_factor, ymin * y_factor, xmax * x_factor, ymax * y_factor]
        for xmin, ymin, xmax, ymax in bounding_boxes
    ]
# Default system prompt for object detection.
# The requested coordinate scale (1000 x 1000) has to agree with the grid that
# rescale_bounding_boxes() assumes by default when the answer is mapped back to
# image pixels; a different scale here would misplace every drawn box.
default_system_prompt = (
    "You are a helpful assistant to detect objects in images. When asked to detect elements based on a description, "
    "you return bounding boxes for all elements in the form of [xmin, ymin, xmax, ymax] with the values being scaled "
    "to 1000 by 1000 pixels. When there are more than one result, answer with a list of bounding boxes in the form "
    "of [[xmin, ymin, xmax, ymax], [xmin, ymin, xmax, ymax], ...]. "
    "Parse only the boxes; don't write unnecessary content."
)
# Function for object detection
# Matches one "[xmin, ymin, xmax, ymax]" group of non-negative integers.
_BOX_PATTERN = re.compile(r'\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]')


@spaces.GPU
def run_example(image, text_input, system_prompt=None):
    """Detect objects in an image and return bounding box annotations.

    The image, the system prompt and the query are sent to ``model_x`` as one
    user message. Every ``[a, b, c, d]`` integer group found in the reply is
    treated as a box, rescaled to the image size with
    ``rescale_bounding_boxes`` and drawn on a copy of the image.

    Args:
        image: PIL image to analyse.
        text_input: Description of what to detect.
        system_prompt: Instruction text placed before the query. When omitted
            (the cached UI examples supply only the image and the query),
            ``default_system_prompt`` is used.

    Returns:
        ``(model_text, parsed_boxes_as_str, annotated_image)``. If no image is
        supplied, a hint message, ``"[]"`` and ``None`` are returned instead.
    """
    if image is None:
        return "Please upload an image.", "[]", None
    if system_prompt is None:
        system_prompt = default_system_prompt

    model = model_x
    processor = processor_x
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": f"data:image;base64,{image_to_base64(image)}"},
                {"type": "text", "text": system_prompt},
                {"type": "text", "text": text_input},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Follow the module-level device (where the models were placed) instead of
    # assuming CUDA, so the CPU fallback keeps working.
    inputs = inputs.to(device)
    generated_ids = model.generate(**inputs, max_new_tokens=256)
    # Drop the prompt tokens so only the newly generated part is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    response = output_text[0]
    # Search the decoded text itself rather than the repr of the list: in a
    # repr, real newlines turn into the two characters "\n" and would stop
    # the whitespace-tolerant pattern from matching boxes split across lines.
    parsed_boxes = [[int(num) for num in match] for match in _BOX_PATTERN.findall(response)]
    scaled_boxes = rescale_bounding_boxes(parsed_boxes, image.width, image.height)
    annotated_image = draw_bounding_boxes(image.copy(), scaled_boxes)
    return response, str(parsed_boxes), annotated_image
def downsample_video(video_path, num_frames=10):
    """Sample evenly spaced frames from a video.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        num_frames: Maximum number of frames to sample (default 10). Videos
            with fewer frames are sampled once per frame instead of repeating
            frames.

    Returns:
        A list of ``(PIL.Image, timestamp_seconds)`` tuples in playback order.
        Frames that fail to decode are skipped, and the list is empty when the
        video reports no frames. The timestamp is ``frame_index / fps`` rounded
        to two decimals, or ``0.0`` when the container reports no frame rate.
    """
    frames = []
    vidcap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        # An unreadable or empty video reports 0 frames; sampling it would
        # produce negative frame indices.
        if total_frames <= 0 or num_frames <= 0:
            return frames
        sample_count = min(num_frames, total_frames)
        frame_indices = np.linspace(0, total_frames - 1, sample_count, dtype=int)
        for index in frame_indices:
            frame_no = int(index)
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_no)
            success, image = vidcap.read()
            if not success:
                continue
            # OpenCV decodes to BGR; PIL expects RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            timestamp = round(frame_no / fps, 2) if fps and fps > 0 else 0.0
            frames.append((Image.fromarray(image), timestamp))
    finally:
        # Release the capture even if decoding or conversion raises.
        vidcap.release()
    return frames
@spaces.GPU
def generate_image(model_name: str, text: str, image: Image.Image,
                   max_new_tokens: int = 1024,
                   temperature: float = 0.6,
                   top_p: float = 0.9,
                   top_k: int = 50,
                   repetition_penalty: float = 1.2):
    """Stream a response from the selected model for a single image.

    Args:
        model_name: One of "Camel-Doc-OCR-062825", "ViLaSR-7B", "OCRFlux-3B"
            or "ShotVL-7B".
        text: The user query sent together with the image.
        image: The PIL image to analyse.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature, top_p, top_k, repetition_penalty: Sampling settings
            forwarded to ``model.generate``.

    Yields:
        ``(buffer, buffer)`` pairs holding the text generated so far, once per
        streamed chunk. For an unknown model name or a missing image a single
        pair with an explanatory message is yielded instead.
    """
    available = {
        "Camel-Doc-OCR-062825": (processor_m, model_m),
        "ViLaSR-7B": (processor_x, model_x),
        "OCRFlux-3B": (processor_t, model_t),
        "ShotVL-7B": (processor_s, model_s),
    }
    selected = available.get(model_name)
    if selected is None:
        yield "Invalid model selected.", "Invalid model selected."
        return
    processor, model = selected

    if image is None:
        yield "Please upload an image.", "Please upload an image."
        return

    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": text},
        ]
    }]
    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt_full],
        images=[image],
        return_tensors="pt",
        padding=True,
        truncation=False,
        max_length=MAX_INPUT_TOKEN_LENGTH
    ).to(device)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    # Forward the sampling settings: they are part of the signature and wired
    # to the UI sliders, and generate_video already passes them; without this
    # they were silently ignored for image requests.
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }
    # generate() blocks until it finishes, so it runs in a worker thread while
    # this generator drains the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer, buffer
@spaces.GPU
def generate_video(model_name: str, text: str, video_path: str,
                   max_new_tokens: int = 1024,
                   temperature: float = 0.6,
                   top_p: float = 0.9,
                   top_k: int = 50,
                   repetition_penalty: float = 1.2):
    """Stream a response from the selected model for a video.

    Frames returned by ``downsample_video`` are appended to the user message,
    each one preceded by a "Frame <timestamp>:" text entry, and the model's
    answer is streamed back while it is generated with sampling enabled.

    Yields:
        ``(buffer, buffer)`` pairs holding the text generated so far. For an
        unknown model name or a missing video a single pair with an
        explanatory message is yielded instead.
    """
    registry = {
        "Camel-Doc-OCR-062825": (processor_m, model_m),
        "ViLaSR-7B": (processor_x, model_x),
        "OCRFlux-3B": (processor_t, model_t),
        "ShotVL-7B": (processor_s, model_s),
    }
    if model_name not in registry:
        yield "Invalid model selected.", "Invalid model selected."
        return
    processor, model = registry[model_name]

    if video_path is None:
        yield "Please upload a video.", "Please upload a video."
        return

    sampled_frames = downsample_video(video_path)

    # The user turn starts with the query; frames are appended after it.
    user_content = [{"type": "text", "text": text}]
    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {"role": "user", "content": user_content},
    ]
    for pil_frame, stamp in sampled_frames:
        user_content.append({"type": "text", "text": f"Frame {stamp}:"})
        user_content.append({"type": "image", "image": pil_frame})

    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
        truncation=False,
        max_length=MAX_INPUT_TOKEN_LENGTH
    ).to(device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(inputs)
    generation_kwargs.update(
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )

    # generate() runs in a worker thread while this generator drains the streamer.
    worker = Thread(target=model.generate, kwargs=generation_kwargs)
    worker.start()

    buffer = ""
    for chunk in streamer:
        # Strip any end-of-turn marker that made it into the streamed text.
        buffer = (buffer + chunk).replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer, buffer
# Example rows offered in the UI. Column order must match the `inputs` list of
# the gr.Examples component that consumes each table.

# [query, image path]
image_examples = [
    ["convert this page to doc [text] precisely for markdown.", "images/1.png"],
    ["convert this page to doc [table] precisely for markdown.", "images/2.png"],
    ["explain the movie shot in detail.", "images/3.png"],
    ["fill the correct numbers.", "images/4.png"],
]

# [query, video path]
video_examples = [
    ["explain the ad video in detail.", "videos/1.mp4"],
    ["explain the video in detail.", "videos/2.mp4"],
]

# [image path, query] -- note the reversed column order for detection.
object_detection_examples = [
    ["object/1.png", "detect red and yellow cars."],
    ["object/2.png", "detect the white cat."],
]
# Stylesheet passed to gr.Blocks below. It colours elements carrying the
# "submit-btn" class (including a hover colour) and gives elements carrying
# the "canvas-output" class a bordered, rounded, padded frame.
css = """
.submit-btn {
background-color: #2980b9 !important;
color: white !important;
}
.submit-btn:hover {
background-color: #3498db !important;
}
.canvas-output {
border: 2px solid #4682B4;
border-radius: 10px;
padding: 20px;
}
"""
# Create the Gradio Interface
# NOTE(review): this file was recovered from a copy that had lost its
# indentation, so the container nesting below is a reconstruction -- confirm
# that the layout renders as intended.
with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
    gr.Markdown("# **[Doc VLMs v2 [Localization]](https://huggingface.co/collections/prithivMLmods/multimodal-implementations-67c9982ea04b39f0608badb0)**")
    with gr.Row():
        # Left column: input tabs plus the shared generation settings.
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Image Inference"):
                    image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                    image_upload = gr.Image(type="pil", label="Image")
                    image_submit = gr.Button("Submit", elem_classes="submit-btn")
                    gr.Examples(
                        examples=image_examples,
                        inputs=[image_query, image_upload]
                    )
                with gr.TabItem("Video Inference"):
                    video_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
                    video_upload = gr.Video(label="Video")
                    video_submit = gr.Button("Submit", elem_classes="submit-btn")
                    gr.Examples(
                        examples=video_examples,
                        inputs=[video_query, video_upload]
                    )
                with gr.TabItem("Object Detection / Localization"):
                    with gr.Row():
                        with gr.Column():
                            input_img = gr.Image(label="Input Image", type="pil")
                            # Hidden field: it carries the default detection prompt
                            # into the click handler without showing it to the user.
                            system_prompt = gr.Textbox(label="System Prompt", value=default_system_prompt, visible=False)
                            text_input = gr.Textbox(label="Query Input")
                            submit_btn = gr.Button(value="Submit", elem_classes="submit-btn")
                        with gr.Column():
                            model_output_text = gr.Textbox(label="Model Output Text")
                            parsed_boxes = gr.Textbox(label="Parsed Boxes")
                            annotated_image = gr.Image(label="Annotated Image")
                    # NOTE(review): only two inputs are listed here, while the
                    # click handler below passes three to the same function.
                    gr.Examples(
                        examples=object_detection_examples,
                        inputs=[input_img, text_input],
                        outputs=[model_output_text, parsed_boxes, annotated_image],
                        fn=run_example,
                        cache_examples=True,
                    )
                    submit_btn.click(
                        fn=run_example,
                        inputs=[input_img, text_input, system_prompt],
                        outputs=[model_output_text, parsed_boxes, annotated_image]
                    )
            # These sliders are passed to both generate_image and generate_video
            # by the click handlers at the end of this block.
            with gr.Accordion("Advanced options", open=False):
                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
        # Right column: streamed output, model selector and model notes.
        with gr.Column():
            with gr.Column(elem_classes="canvas-output"):
                gr.Markdown("## Result.Md")
                output = gr.Textbox(label="Raw Output Stream", interactive=False, lines=2)
                markdown_output = gr.Markdown(label="Formatted Result (Result.Md)")
            # The choice strings must match the names checked in generate_image
            # and generate_video.
            model_choice = gr.Radio(
                choices=["Camel-Doc-OCR-062825", "OCRFlux-3B", "ShotVL-7B", "ViLaSR-7B"],
                label="Select Model",
                value="Camel-Doc-OCR-062825"
            )
            gr.Markdown("**Model Info 💻** | [Report Bug](https://huggingface.co/spaces/prithivMLmods/Doc-VLMs-v2-Localization/discussions)")
            gr.Markdown("> [Camel-Doc-OCR-062825](https://huggingface.co/prithivMLmods/Camel-Doc-OCR-062825) : camel-doc-ocr-062825 model is a fine-tuned version of qwen2.5-vl-7b-instruct, optimized for document retrieval, content extraction, and analysis recognition. built on top of the qwen2.5-vl architecture, this model enhances document comprehension capabilities.")
            gr.Markdown("> [OCRFlux-3B](https://huggingface.co/ChatDOC/OCRFlux-3B) : ocrflux-3b model that's fine-tuned from qwen2.5-vl-3b-instruct using our private document datasets and some data from olmocr-mix-0225 dataset. optimized for document retrieval, content extraction, and analysis recognition. the best way to use this model is via the ocrflux toolkit.")
            gr.Markdown("> [ViLaSR](https://huggingface.co/AntResearchNLP/ViLaSR) : vilasr-7b model as presented in reinforcing spatial reasoning in vision-language models with interwoven thinking and visual drawing. efficient reasoning capabilities.")
            gr.Markdown("> [ShotVL-7B](https://huggingface.co/Vchitect/ShotVL-7B) : shotvl-7b is a fine-tuned version of qwen2.5-vl-7b-instruct, trained by supervised fine-tuning on the largest and high-quality dataset for cinematic language understanding to date. it currently achieves state-of-the-art performance on shotbench.")
            gr.Markdown(">⚠️note: all the models in space are not guaranteed to perform well in video inference use cases.")
    # Both handlers write to the same two output components: the raw text box
    # and the Markdown view.
    image_submit.click(
        fn=generate_image,
        inputs=[model_choice, image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[output, markdown_output]
    )
    video_submit.click(
        fn=generate_video,
        inputs=[model_choice, video_query, video_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[output, markdown_output]
    )
if __name__ == "__main__":
    # Enable the request queue (max_size=30) and start the app with the
    # options below when the file is run as a script.
    launch_options = {
        "share": True,
        "mcp_server": True,
        "ssr_mode": False,
        "show_error": True,
    }
    demo.queue(max_size=30).launch(**launch_options)