import os
import base64
import tempfile
from io import BytesIO
from urllib.request import urlretrieve

import gradio as gr
from gradio_pdf import PDF
import torch
from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

from colpali_engine.models import ColQwen2, ColQwen2Processor

# -----------------------------
# Globals
# -----------------------------
api_key = os.getenv("OPENAI_API_KEY", "")  # <- use env var

ds = []  # list of document embeddings (torch tensors)
images = []  # list of PIL images (page order)
current_pdf_path = None  # last (indexed) pdf path for preview

# -----------------------------
# Model & processor
# -----------------------------
device_map = (
    "cuda:0"
    if torch.cuda.is_available()
    else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
)

model = ColQwen2.from_pretrained(
    "vidore/colqwen2-v1.0",
    torch_dtype=torch.bfloat16,
    device_map=device_map,
    attn_implementation="flash_attention_2",
).eval()
processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")

# -----------------------------
# Utilities
# -----------------------------
def encode_image_to_base64(image: Image.Image) -> str:
    """Encodes a PIL image to a base64 string."""
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def query_gpt(query: str, retrieved_images: list[tuple[Image.Image, str]]) -> str:
    """Calls OpenAI's GPT model with the query and image data."""
    if api_key and api_key.startswith("sk"):
        try:
            from openai import OpenAI

            base64_images = [encode_image_to_base64(im_caption[0]) for im_caption in retrieved_images]
            client = OpenAI(api_key=api_key.strip())
            PROMPT = """
You are a smart assistant designed to answer questions about a PDF document.
You are given relevant information in the form of PDF pages.
Use them to construct a short response to the question, and cite your sources (page numbers, etc).
If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
Give detailed and extensive answers, only containing info in the pages you are given.
You can answer using information contained in plots and figures if necessary.
Answer in the same language as the query.

Query: {query}
PDF pages:
""".strip()

            response = client.responses.create(
                model="gpt-5-mini",
                input=[
                    {
                        "role": "user",
                        "content": (
                            [{"type": "input_text", "text": PROMPT.format(query=query)}]
                            + [
                                {"type": "input_image", "image_url": f"data:image/jpeg;base64,{im}"}
                                for im in base64_images
                            ]
                        ),
                    }
                ],
                # max_tokens=500,
            )
            return response.output_text
        except Exception as e:
            print(e)
            return "OpenAI API connection failure. Verify that OPENAI_API_KEY is set and valid (sk-***)."
    return "Set OPENAI_API_KEY in your environment to get a custom response."
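

# Usage sketch (hypothetical values, assuming OPENAI_API_KEY is exported and a rendered
# page image is available): query_gpt expects a list of (PIL.Image, caption) tuples, e.g.
#   answer = query_gpt("What is covered in section 5?", [(page_image, "Page 3")])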
def _ensure_model_device():
    dev = (
        "cuda:0"
        if torch.cuda.is_available()
        else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
    )
    if str(model.device) != dev:
        model.to(dev)
    return dev

# -----------------------------
# Indexing helpers
# -----------------------------
def convert_files(pdf_path: str) -> list[Image.Image]:
    """Convert a single PDF path into a list of PIL Images (pages)."""
    imgs = convert_from_path(pdf_path, thread_count=4)
    if len(imgs) >= 800:
        raise gr.Error("The number of images in the dataset should be less than 800.")
    return imgs


def index_gpu(imgs: list[Image.Image]) -> str:
    """Embed a list of images (pages) with ColPali and store in globals."""
    global ds, images
    device = _ensure_model_device()

    # reset previous dataset
    ds = []
    images = imgs

    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: processor.process_images(x).to(model.device),
    )
    for batch_doc in tqdm(dataloader, desc="Indexing pages"):
        with torch.no_grad():
            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    return f"Indexed {len(images)} pages successfully."


def index_from_path(pdf_path: str) -> str:
    """Public: index a local PDF file path."""
    imgs = convert_files(pdf_path)
    return index_gpu(imgs)


def index_from_url(url: str) -> tuple[str, str]:
    """
    Download a PDF from a URL and index it.

    Returns: status message, saved pdf path
    """
    tmp_dir = tempfile.mkdtemp(prefix="colpali_")
    local_path = os.path.join(tmp_dir, "document.pdf")
    urlretrieve(url, local_path)
    status = index_from_path(local_path)
    return status, local_path

# -----------------------------
# Search (MCP tool-friendly)
# -----------------------------
def search(query: str, k: int = 5):
    """
    Search within a PDF document for the most relevant pages to answer a query and
    synthesize a short grounded answer using only those pages.

    MCP tool description:
    - name: mcp_test_search
    - description: Search within a PDF document for the most relevant pages to answer
      a query and synthesize a short grounded answer using only those pages.
    - input_schema:
        type: object
        properties:
          query: {type: string, description: "User query in natural language."}
          k: {type: integer, minimum: 1, maximum: 10, default: 5, description: "Number of top pages to retrieve."}
        required: ["query"]

    Args:
        query (str): Natural-language question to search for.
        k (int): Number of top results to return (1–10).

    Returns:
        ai_response (str): Text answer to the query grounded in content from the PDF,
        with citations (page numbers).
    """
    global ds, images
    if not images or not ds:
        return "No document indexed yet. Upload a PDF or load the sample, then run Search."
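
    # Retrieval flow for the steps below: the query is embedded into multi-vector token
    # embeddings, scored against every page's embeddings with processor.score
    # (late-interaction, MaxSim-style scoring), and the top-k page indices are expanded
    # with their immediate neighbors for extra context before being passed to the LLM.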
    k = max(1, min(int(k), len(images)))
    device = _ensure_model_device()
    print(query)

    # Encode query
    qs = []
    with torch.no_grad():
        batch_query = processor.process_queries([query]).to(model.device)
        embeddings_query = model(**batch_query)
    qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

    # Score and select top-k
    scores = processor.score(qs, ds, device=device)
    top_k_indices = scores[0].topk(k).indices.tolist()

    # Base set & neighbor expansion
    base = set(top_k_indices)
    expanded = set(base)
    for i in base:
        expanded.add(i - 1)
        expanded.add(i + 1)
    expanded = {i for i in expanded if 0 <= i < len(images)}
    expanded_indices = sorted(expanded)
    print(top_k_indices, expanded_indices)

    # Build gallery results with 1-based page numbering
    results = []
    for idx in expanded_indices:
        page_num = idx + 1
        results.append((images[idx], f"Page {page_num}"))

    # Generate grounded response
    ai_response = query_gpt(query, results)
    print(ai_response)
    return ai_response

# -----------------------------
# Gradio UI callbacks
# -----------------------------
def handle_upload(file) -> tuple[str, str | None]:
    """Index a user-uploaded PDF file."""
    global current_pdf_path
    if file is None:
        return "Please upload a PDF.", None
    path = getattr(file, "name", file)
    status = index_from_path(path)
    current_pdf_path = path
    return status, path


def handle_url(url: str) -> tuple[str, str | None]:
    """Index a PDF from a URL (e.g., a sample document)."""
    global current_pdf_path
    if not url or not url.lower().endswith(".pdf"):
        return "Please provide a direct PDF URL.", None
    status, path = index_from_url(url)
    current_pdf_path = path
    return status, path


print("Uploading")
print(handle_url("https://ecss.nl/wp-content/uploads/2025/05/ECSS-E-ST-40C-Rev.1(30April2025).pdf"))

# -----------------------------
# Gradio App
# -----------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) 📚")
    gr.Markdown(
        """Demo to test ColQwen2 (ColPali) on PDF documents.
        ColPali is implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449)."""
    )

    with gr.Row():
        # with gr.Column(scale=2):
        #     gr.Markdown("## 1️⃣ Load a PDF")
        #     pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
        #     index_btn = gr.Button("📥 Index Uploaded PDF", variant="secondary")
        #     url_box = gr.Textbox(
        #         label="Or index from URL",
        #         placeholder="https://example.com/file.pdf",
        #         value="https://sist.sathyabama.ac.in/sist_coursematerial/uploads/SAR1614.pdf",
        #     )
        #     index_url_btn = gr.Button("🌐 Load Sample / From URL", variant="secondary")
        #     status_box = gr.Textbox(label="Status", interactive=False)
        #     pdf_view = PDF(label="PDF Preview")

        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k_slider = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)
            search_button = gr.Button("🔍 Search", variant="primary")
            output_text = gr.Textbox(label="AI Response", placeholder="Generated response based on retrieved documents")

    # Wiring
    # index_btn.click(handle_upload, inputs=[pdf_input], outputs=[status_box, pdf_view])
    # index_url_btn.click(handle_url, inputs=[url_box], outputs=[status_box, pdf_view])
    search_button.click(search, inputs=[query, k_slider], outputs=[output_text])

if __name__ == "__main__":
    # Optional: pre-load the default sample at startup.
    # Uncomment these two lines if you prefer to index it here rather than at import time.
    # msg, path = index_from_url("https://sist.sathyabama.ac.in/sist_coursematerial/uploads/SAR1614.pdf")
    # print(msg, "->", path)
    demo.queue(max_size=5).launch(debug=True, mcp_server=True)
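
# Client-side usage sketch (assumes the app is running locally on Gradio's default port;
# the "/search" api_name is inferred from the function name and may differ in your setup):
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   print(client.predict("What does the standard require for software reuse?", 5, api_name="/search"))
# With mcp_server=True, Gradio also exposes `search` as an MCP tool that MCP-compatible
# clients can call with the same query/k arguments.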