# Spaces: Running on Zero
# Kept first: ZeroGPU's `spaces` package is meant to be imported before
# torch / CUDA-related libraries.
import spaces

import json
import os
from typing import Tuple, Type

import gradio as gr
import torch
from pydantic import BaseModel
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
# NOTE(review): huggingface_hub appears to read HF_HUB_ENABLE_HF_TRANSFER when it
# is first imported, and `transformers` (imported above) already pulls it in, so
# this assignment may have no effect on the downloads below -- confirm. Moving it
# above the imports would additionally require the `hf_transfer` package.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Single checkpoint id shared by the model and its processor.
_MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"

# Loaded once, at import time, in bfloat16; device_map="auto" leaves device
# placement to transformers.
# Optional: add attn_implementation="flash_attention_2" if flash-attn is installed.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    _MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
processor = AutoProcessor.from_pretrained(_MODEL_ID)
class GeneralRetrievalQuery(BaseModel):
    # Schema of the JSON object the "general" retrieval prompt asks the model to
    # produce: three query types, each paired with a short explanation of why the
    # query would retrieve the page. Field order mirrors the JSON layout spelled
    # out in the prompt text; in the code visible here the parsed model output is
    # not validated against this class.
    broad_topical_query: str  # query covering the page's main subject
    broad_topical_explanation: str
    specific_detail_query: str  # query targeting a particular fact, figure, or point
    specific_detail_explanation: str
    # Query referencing a chart/graph/image; per the prompt, replaced by another
    # specific-detail query when the page has no relevant visual elements.
    visual_element_query: str
    visual_element_explanation: str
def get_retrieval_prompt(prompt_name: str) -> Tuple[str, Type[GeneralRetrievalQuery]]:
    """Return the instruction prompt and the response schema for a prompt variant.

    Args:
        prompt_name: Name of the prompt variant. Only ``"general"`` is supported.

    Returns:
        A ``(prompt, model_cls)`` pair: the text prompt sent alongside the page
        image, and the pydantic model *class* (not an instance) describing the
        JSON object the prompt asks the model to return.

    Raises:
        ValueError: If ``prompt_name`` is anything other than ``"general"``.
    """
    if prompt_name != "general":
        raise ValueError("Only 'general' prompt is available in this version")

    prompt = """You are an AI assistant specialized in document retrieval tasks. Given an image of a document page, your task is to generate retrieval queries that someone might use to find this document in a large corpus.
Please generate 3 different types of retrieval queries:
1. A broad topical query: This should cover the main subject of the document.
2. A specific detail query: This should focus on a particular fact, figure, or point made in the document.
3. A visual element query: This should reference a chart, graph, image, or other visual component in the document, if present.
Important guidelines:
- Ensure the queries are relevant for retrieval tasks, not just describing the page content.
- Frame the queries as if someone is searching for this document, not asking questions about its content.
- Make the queries diverse and representative of different search strategies.
For each query, also provide a brief explanation of why this query would be effective in retrieving this document.
Format your response as a JSON object with the following structure:
{
"broad_topical_query": "Your query here",
"broad_topical_explanation": "Brief explanation",
"specific_detail_query": "Your query here",
"specific_detail_explanation": "Brief explanation",
"visual_element_query": "Your query here",
"visual_element_explanation": "Brief explanation"
}
If there are no relevant visual elements, replace the third query with another specific detail query.
Here is the document image to analyze:
<image>
Generate the queries based on this image and provide the response in the specified JSON format."""
    # The schema class itself is returned so callers can validate/construct later.
    return prompt, GeneralRetrievalQuery
# Prompt text and schema are fetched by name so further prompt variants can be
# added later; only "general" exists today.
prompt, pydantic_model = get_retrieval_prompt(prompt_name="general")
def _prep_data_for_input(image):
    """Build the processor inputs for a single-turn chat: one image plus the prompt.

    Args:
        image: The document-page image placed in the user turn.

    Returns:
        Whatever ``processor(...)`` returns for a batch of one conversation,
        padded and converted to PyTorch tensors.
    """
    image_part = {"type": "image", "image": image}
    text_part = {"type": "text", "text": prompt}
    conversation = [{"role": "user", "content": [image_part, text_part]}]

    # Render the chat template as plain text (tokenization happens in the
    # processor call below) and append the assistant-turn generation prompt.
    chat_text = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False
    )
    images, videos = process_vision_info(conversation)

    return processor(
        text=[chat_text],
        images=images,
        videos=videos,
        padding=True,
        return_tensors="pt",
    )
def _parse_json_response(text):
    """Decode a model reply into a JSON value, tolerating common wrappers.

    Chat models may wrap JSON in a markdown code fence (```json ... ```) or
    surround it with prose; both would make a plain ``json.loads`` fail.

    Args:
        text: The decoded model reply.

    Returns:
        The decoded JSON value, or ``{}`` if no parseable JSON is found.
    """
    candidate = text.strip()
    if candidate.startswith("```"):
        # Drop the opening fence line (``` or ```json) and any closing fence.
        candidate = candidate.split("\n", 1)[1] if "\n" in candidate else ""
        candidate = candidate.rstrip()
        if candidate.endswith("```"):
            candidate = candidate[:-3]
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        pass

    # Fall back to the outermost {...} span in case prose surrounds the object.
    start, end = candidate.find("{"), candidate.rfind("}")
    if start != -1 and end > start:
        try:
            return json.loads(candidate[start : end + 1])
        except json.JSONDecodeError:
            pass
    return {}


@spaces.GPU  # ZeroGPU: a GPU is attached only while a decorated function runs.
def generate_response(image):
    """Generate retrieval queries for one document-page image.

    Args:
        image: The page image from the ``gr.Image(type="pil")`` input; may be
            ``None`` if the form is submitted without an upload.

    Returns:
        The JSON value decoded from the model's reply (the prompt asks for an
        object with the ``GeneralRetrievalQuery`` keys), or ``{}`` when the
        reply contains no parseable JSON.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload a document image.")

    # Follow the model's own device instead of hard-coding "cuda", since
    # placement was delegated to device_map="auto".
    inputs = _prep_data_for_input(image).to(model.device)

    # NOTE(review): the requested object has six fields including explanations;
    # 200 new tokens may truncate it, and truncated JSON parses to {} -- consider
    # raising this limit.
    generated_ids = model.generate(**inputs, max_new_tokens=200)

    # generate() returns prompt + completion per row; keep only the new tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return _parse_json_response(output_text[0])
# Web UI: one PIL image in, the generated queries rendered as JSON out.
demo = gr.Interface(
    fn=generate_response,
    inputs=gr.Image(type="pil"),
    outputs="json",
)
demo.launch()