# app.py — Unified ColPali + MCP Agent (indices-only search, agent receives images)
import os
import base64
import json
import re
import tempfile
from io import BytesIO
from urllib.request import urlretrieve
from typing import List, Tuple, Dict, Any, Optional

import gradio as gr
from gradio_pdf import PDF
import torch
from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

from colpali_engine.models import ColQwen2, ColQwen2Processor

# Optional (used by the streaming agent)
from openai import OpenAI

# =============================
# Globals & Config
# =============================
api_key_env = os.getenv("OPENAI_API_KEY", "").strip()

ds: List[torch.Tensor] = []      # page embeddings
images: List[Image.Image] = []   # PIL images in page order
current_pdf_path: Optional[str] = None

device_map = (
    "cuda:0"
    if torch.cuda.is_available()
    else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
)

# =============================
# Load Model & Processor
# =============================
# NOTE: attn_implementation="flash_attention_2" requires the flash-attn package
# and a compatible GPU; drop the argument to fall back to the default attention.
model = ColQwen2.from_pretrained(
    "vidore/colqwen2-v1.0",
    torch_dtype=torch.bfloat16,
    device_map=device_map,
    attn_implementation="flash_attention_2",
).eval()
processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")

# =============================
# Utilities
# =============================
def _ensure_model_device() -> str:
    dev = (
        "cuda:0"
        if torch.cuda.is_available()
        else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
    )
    if str(model.device) != dev:
        model.to(dev)
    return dev


def encode_image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image to base64 (JPEG)."""
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
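
# Callers are expected to prepend the data-URL scheme themselves, as
# _build_image_parts_from_indices does below, e.g.:
#   b64 = encode_image_to_base64(page_image)
#   url = f"data:image/jpeg;base64,{b64}"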

# =============================
# Indexing Helpers
# =============================
def convert_files(pdf_path: str) -> List[Image.Image]:
    """Convert a single PDF path into a list of PIL Images (pages)."""
    imgs = convert_from_path(pdf_path, thread_count=4)
    if len(imgs) >= 800:
        raise gr.Error("The number of images in the dataset should be less than 800.")
    return imgs


def index_gpu(imgs: List[Image.Image]) -> str:
    """Embed a list of images (pages) with ColQwen2 (ColPali) and store in globals."""
    global ds, images
    device = _ensure_model_device()

    # Reset the previous dataset
    ds = []
    images = imgs

    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: processor.process_images(x).to(model.device),
    )
    for batch_doc in tqdm(dataloader, desc="Indexing pages"):
        with torch.no_grad():
            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    return f"Indexed {len(images)} pages successfully."


def index_from_path(pdf_path: str) -> str:
    imgs = convert_files(pdf_path)
    return index_gpu(imgs)


def index_from_url(url: str) -> Tuple[str, str]:
    """
    Download a PDF from a URL and index it.

    Returns:
        (status_message, saved_pdf_path)
    """
    tmp_dir = tempfile.mkdtemp(prefix="colpali_")
    local_path = os.path.join(tmp_dir, "document.pdf")
    urlretrieve(url, local_path)
    status = index_from_path(local_path)
    return status, local_path
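
# Example (using the same placeholder URL as the UI below):
#   status, path = index_from_url("https://example.com/file.pdf")
#   print(status)  # -> "Indexed <N> pages successfully."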

def _build_image_parts_from_indices(indices: List[int]) -> List[Dict[str, Any]]:
    """Turn page indices into OpenAI vision content parts."""
    parts: List[Dict[str, Any]] = []
    valid = sorted({i for i in indices if 0 <= i < len(images)})
    for idx in valid:
        b64 = encode_image_to_base64(images[idx])
        parts.append({
            "type": "input_image",
            "image_url": f"data:image/jpeg;base64,{b64}",
        })
    return parts

# =============================
# MCP Tools
# =============================
def search(query: str, k: int = 5) -> List[int]:
    """
    Search within an indexed PDF and return ONLY the indices of the most relevant pages (0-based).

    MCP tool description:
    - name: mcp_test_search
    - description: Search within the indexed PDF for the most relevant pages and return their 0-based indices only.
    - input_schema:
        type: object
        properties:
          query: {type: string, description: "User query in natural language."}
          k: {type: integer, minimum: 1, maximum: 50, default: 5, description: "Number of top pages to retrieve (before neighbor expansion)."}
        required: ["query"]

    Returns:
        List[int]: Sorted unique 0-based indices of pages to inspect (includes neighbor expansion).
    """
    global ds, images
    if not images or not ds:
        return []

    k = max(1, min(int(k), len(images)))
    device = _ensure_model_device()

    # Encode the query
    with torch.no_grad():
        batch_query = processor.process_queries([query]).to(model.device)
        embeddings_query = model(**batch_query)
    q_vecs = list(torch.unbind(embeddings_query.to("cpu")))

    # Score and select top-k
    scores = processor.score(q_vecs, ds, device=device)
    top_k_indices = scores[0].topk(k).indices.tolist()
    print(query, top_k_indices)

    # Neighbor expansion for context
    base = set(top_k_indices)
    expanded = set(base)
    for i in base:
        expanded.add(i - 1)
        expanded.add(i + 1)
    expanded = {i for i in expanded if 0 <= i < len(images)}  # strict bounds
    return sorted(expanded)
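
# The SYSTEM prompt and UI below reference a companion MCP tool,
# mcp_test_get_pages, which is not defined in this file. The following is a
# minimal sketch of what it could look like, assuming it receives the 0-based
# indices produced by `search` and returns base64-encoded JPEG strings:
def get_pages(indices: List[int]) -> List[str]:
    """
    Return base64-encoded JPEGs for the given 0-based page indices.

    MCP tool description (assumed):
    - name: mcp_test_get_pages
    - description: Fetch the actual page images for previously returned indices.
    """
    valid = sorted({i for i in map(int, indices) if 0 <= i < len(images)})
    return [encode_image_to_base64(images[i]) for i in valid]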

# =============================
# Gradio UI — Unified App
# =============================
SYSTEM = (
    """
You are a PDF research agent with two tools:
• mcp_test_search(query: string, k: int) → returns ONLY 0-based page indices.
• mcp_test_get_pages(indices: int[]) → returns the actual page images (as base64 images) for vision.

Policy & procedure:
1) Break the user task into 1–4 targeted sub-queries (in English).
2) For each sub-query, call mcp_test_search to get indices. Once you receive the indices to use, print "Received" and stop generating; the images will be injected into your stream.
3) Continue reasoning using ONLY the provided images. If the information is insufficient, iterate: refine your sub-queries and call the tools again. You may make further tool calls later in the conversation as needed.

Grounding & citations:
• Use ONLY information visible in the provided page images.
• After any claim, cite as (p.<page>).
• If an answer is not present, say “Not found in the provided pages.”

Final deliverable:
• Write a clear, standalone Markdown answer in the user's language. For lists of dates/items, include a concise table.
• Do not refer to “the above” or “previous messages”.
"""
).strip()

DEFAULT_MCP_SERVER_URL = "https://manu-mcp-test.hf.space/gradio_api/mcp/"
DEFAULT_MCP_SERVER_LABEL = "colpali_rag"
DEFAULT_ALLOWED_TOOLS = "mcp_test_search,mcp_test_get_pages"

def stream_agent(question: str,
                 api_key: str,
                 model_name: str,
                 server_url: str,
                 server_label: str,
                 require_approval: str,
                 allowed_tools: str):
    """
    Streaming generator for the agent.

    NOTE: We rely on OpenAI's MCP tool routing. The mcp_test_search tool returns indices only;
    the agent is instructed to call mcp_test_get_pages next to receive images and continue reasoning.
    """
    log_lines = ["Log"]

    if not api_key:
        yield "⚠️ **Please provide your OpenAI API key.**", "", ""
        return

    client = OpenAI(api_key=api_key)
    prev_response_id: Optional[str] = None

    tools = [{
        "type": "mcp",
        "server_label": server_label or DEFAULT_MCP_SERVER_LABEL,
        "server_url": server_url or DEFAULT_MCP_SERVER_URL,
        "allowed_tools": [t.strip() for t in (allowed_tools or DEFAULT_ALLOWED_TOOLS).split(",") if t.strip()],
        "require_approval": require_approval or "never",
    }]

    # Seed pages once (optional)
    seed_indices = search(question, k=5) or []
    pending_indices = list(seed_indices)

    def run_round(round_idx: int, attached_indices: List[int]):
        nonlocal prev_response_id
        assembled_text = ""
        assembled_summary = ""

        # Will hold the most recent indices returned by mcp_test_search in THIS stream
        last_search_indices: List[int] = []

        # Build user parts (attach any seed pages we already have)
        parts: List[Dict[str, Any]] = [{"type": "input_text", "text": question if round_idx == 1 else "Continue with new pages."}]
        parts += _build_image_parts_from_indices(attached_indices)

        # First call includes the system prompt; follow-ups use previous_response_id
        if prev_response_id:
            req_input = [{"role": "user", "content": parts}]
        else:
            req_input = [
                {"role": "system", "content": SYSTEM},
                {"role": "user", "content": parts},
            ]

        req_kwargs = dict(
            model=model_name,
            input=req_input,
            reasoning={"effort": "medium", "summary": "auto"},
            tools=tools,
            store=True,
        )
        if prev_response_id:
            req_kwargs["previous_response_id"] = prev_response_id

        # Helper to try extracting a JSON int array from tool result text
        def _maybe_parse_indices(chunk: str) -> List[int]:
            # Find the last bracketed JSON array in the chunk
            arrs = re.findall(r'\[[^\]]*\]', chunk)
            for s in reversed(arrs):
                try:
                    val = json.loads(s)
                    if isinstance(val, list) and all(isinstance(x, int) for x in val):
                        return sorted(set(val))
                except Exception:
                    pass
            return []
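
        # Example behaviour (hypothetical buffered tool output):
        #   _maybe_parse_indices('{"result": "[3, 4, 5]"}') -> [3, 4, 5]
        #   _maybe_parse_indices('no array here')           -> []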

        tool_result_buffer = ""  # accumulate tool result deltas

        try:
            with client.responses.stream(**req_kwargs) as stream:
                for event in stream:
                    etype = getattr(event, "type", "")
                    if etype == "response.output_text.delta":
                        assembled_text += event.delta
                        yield assembled_text or " ", assembled_summary or " ", "\n".join(log_lines[-400:])
                    elif etype == "response.reasoning_summary_text.delta":
                        assembled_summary += event.delta
                        yield assembled_text or " ", assembled_summary or " ", "\n".join(log_lines[-400:])
                    # Capture tool *arguments* in the log for transparency (optional)
                    elif etype in ("response.function_call_arguments.delta", "response.tool_call_arguments.delta"):
                        log_lines.append(str(event.delta))
                    # Capture tool *results* (indices JSON) from MCP
                    elif etype.startswith("response.tool_result"):
                        # Different SDK versions expose .delta or .output_text; handle both
                        delta = getattr(event, "delta", "") or getattr(event, "output_text", "")
                        if delta:
                            tool_result_buffer += str(delta)
                            # Opportunistic parse so the UI can progress early
                            parsed_now = _maybe_parse_indices(tool_result_buffer)
                            if parsed_now:
                                last_search_indices = parsed_now
                                log_lines.append(f"[tool-result] indices={last_search_indices}")
                            yield assembled_text or " ", assembled_summary or " ", "\n".join(log_lines[-400:])

                # Finalize and remember the response id for follow-ups
                _final = stream.get_final_response()
                prev_response_id = getattr(_final, "id", None)

            # If the model produced search results this round, hand them back to the controller
            if last_search_indices:
                return sorted(set(last_search_indices))

            # Otherwise, just render whatever text we have
            yield assembled_text or " ", assembled_summary or " ", "\n".join(log_lines[-400:])
            return None
        except Exception as e:
            log_lines.append(f"[round {round_idx}] stream error: {e}")
            yield f"❌ {e}", assembled_summary or "", "\n".join(log_lines[-400:])
            return None

    # Controller: iterate rounds until the model stops searching
    max_rounds = 3
    round_idx = 1
    while round_idx <= max_rounds:
        # Start a round with any pending images we already have. run_round is a
        # generator whose return value arrives via StopIteration.value, so we
        # drive it manually instead of using a plain for-loop (which would
        # silently discard the returned indices).
        gen = run_round(round_idx, pending_indices)
        next_indices = None
        while True:
            try:
                final_md, summary_md, log_md = next(gen)
                yield final_md, summary_md, log_md
            except StopIteration as stop:
                next_indices = stop.value
                break

        # If the model called mcp_test_search, we got indices back; fetch those pages next.
        # (We ignore pending_indices now and move to the model-chosen ones.)
        if isinstance(next_indices, list) and next_indices:
            pending_indices = next_indices
            # Attach those pages in a new call chained via previous_response_id
            round_idx += 1
            continue

        # No tool search results this round → we're done
        break
    return

CUSTOM_CSS = """
:root {
  --bg: #0e1117;
  --panel: #111827;
  --accent: #7c3aed;
  --accent-2: #06b6d4;
  --text: #e5e7eb;
  --muted: #9ca3af;
  --border: #1f2937;
}
.gradio-container {max-width: 1180px !important; margin: 0 auto !important;}
body {background: radial-gradient(1200px 600px at 20% -10%, rgba(124,58,237,.25), transparent 60%),
      radial-gradient(1000px 500px at 120% 10%, rgba(6,182,212,.2), transparent 60%),
      var(--bg) !important;}
.app-header {
  display:flex; gap:16px; align-items:center; padding:20px 18px; margin:8px 0 12px;
  border:1px solid var(--border); border-radius:20px;
  background: linear-gradient(180deg, rgba(255,255,255,.02), rgba(255,255,255,.01));
  box-shadow: 0 10px 30px rgba(0,0,0,.25), inset 0 1px 0 rgba(255,255,255,.05);
}
.app-header .icon {
  width:48px; height:48px; display:grid; place-items:center; border-radius:14px;
  background: linear-gradient(135deg, var(--accent), var(--accent-2));
  color:white; font-size:26px;
}
.app-header h1 {font-size:22px; margin:0; color:var(--text); letter-spacing:.2px;}
.app-header p {margin:2px 0 0; color:var(--muted); font-size:14px;}
.card {
  border:1px solid var(--border); border-radius:18px; padding:14px 16px;
  background: linear-gradient(180deg, rgba(255,255,255,.02), rgba(255,255,255,.01));
  box-shadow: 0 12px 28px rgba(0,0,0,.18), inset 0 1px 0 rgba(255,255,255,.04);
}
.gr-button-primary {border-radius:12px !important; font-weight:600;}
.gradio-container .tabs {border-radius:16px; overflow:hidden; border:1px solid var(--border);}
.markdown-wrap {min-height: 260px;}
.summary-wrap {min-height: 180px;}
.gr-markdown, .gr-prose { color: var(--text) !important; }
.gr-markdown h1, .gr-markdown h2, .gr-markdown h3 {color: #f3f4f6;}
.gr-markdown a {color: var(--accent-2); text-decoration: none;}
.gr-markdown a:hover {text-decoration: underline;}
.gr-markdown table {width: 100%; border-collapse: collapse; margin: 10px 0 16px;}
.gr-markdown th, .gr-markdown td {border: 1px solid var(--border); padding: 8px 10px;}
.gr-markdown th {background: rgba(255,255,255,.03);}
.gr-markdown pre, .gr-markdown code { background: #0b1220; color: #eaeaf0; border-radius: 12px; border: 1px solid #172036; }
.gr-markdown pre {padding: 12px 14px; overflow:auto;}
.gr-markdown blockquote { border-left: 4px solid var(--accent); padding: 6px 12px; margin: 8px 0; color: #d1d5db; background: rgba(124,58,237,.06); border-radius: 8px; }
.log-box { font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; white-space: pre-wrap; color: #d1d5db; background:#0b1220; border:1px solid #172036; border-radius:14px; padding:12px; max-height:280px; overflow:auto; }
"""

def build_ui():
    theme = gr.themes.Soft()
    with gr.Blocks(title="ColPali PDF RAG + MCP Agent (Indices-only)", theme=theme, css=CUSTOM_CSS) as demo:
        gr.HTML(
            """
            <div class="app-header">
                <div class="icon">📚</div>
                <div>
                    <h1>ColPali PDF Search + Streaming Agent</h1>
                    <p>Index PDFs with ColQwen2 (ColPali). The search tool returns page indices only; the agent fetches images and reasons visually.</p>
                </div>
            </div>
            """
        )

        with gr.Tab("1) Index & Preview"):
            with gr.Row():
                with gr.Column(scale=1):
                    pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
                    index_btn = gr.Button("📥 Index Uploaded PDF", variant="secondary")
                    url_box = gr.Textbox(
                        label="Or index from URL",
                        placeholder="https://example.com/file.pdf",
                        value="",
                    )
                    index_url_btn = gr.Button("🌐 Load From URL", variant="secondary")
                    status_box = gr.Textbox(label="Status", interactive=False)
                with gr.Column(scale=2):
                    pdf_view = PDF(label="PDF Preview")

            # Wiring
            def handle_upload(file):
                global current_pdf_path
                if file is None:
                    return "Please upload a PDF.", None
                path = getattr(file, "name", file)
                status = index_from_path(path)
                current_pdf_path = path
                return status, path

            def handle_url(url: str):
                global current_pdf_path
                if not url or not url.lower().endswith(".pdf"):
                    return "Please provide a direct PDF URL ending in .pdf", None
                status, path = index_from_url(url)
                current_pdf_path = path
                return status, path

            index_btn.click(handle_upload, inputs=[pdf_input], outputs=[status_box, pdf_view])
            index_url_btn.click(handle_url, inputs=[url_box], outputs=[status_box, pdf_view])
with gr.Tab("2) Ask (Direct — returns indices)"): | |
with gr.Row(): | |
with gr.Column(scale=1): | |
query_box = gr.Textbox(placeholder="Enter your question…", label="Query", lines=4) | |
k_slider = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results (k)", value=5) | |
search_button = gr.Button("🔍 Search", variant="primary") | |
get_pages_button = gr.Button("🔍 Get Pages", variant="primary") | |
with gr.Column(scale=2): | |
output_text = gr.Textbox(label="Indices (0-based)", lines=12, placeholder="[0, 1, 2, ...]") | |
search_button.click(search, inputs=[query_box, k_slider], outputs=[output_text]) | |
with gr.Tab("3) Agent (Streaming)"): | |
with gr.Row(equal_height=True): | |
with gr.Column(scale=1): | |
with gr.Group(): | |
question = gr.Textbox( | |
label="Your question", | |
placeholder="Enter your question…", | |
lines=8, | |
elem_classes=["card"], | |
) | |
run_btn = gr.Button("Run", variant="primary") | |
with gr.Accordion("Connection & Model", open=False, elem_classes=["card"]): | |
with gr.Row(): | |
api_key_box = gr.Textbox( | |
label="OpenAI API Key", | |
placeholder="sk-...", | |
type="password", | |
value=api_key_env, | |
) | |
model_box = gr.Dropdown( | |
label="Model", | |
choices=["gpt-5", "gpt-4.1", "gpt-4o"], | |
value="gpt-5", | |
) | |
with gr.Row(): | |
server_url_box = gr.Textbox( | |
label="MCP Server URL", | |
value=DEFAULT_MCP_SERVER_URL, | |
) | |
server_label_box = gr.Textbox( | |
label="MCP Server Label", | |
value=DEFAULT_MCP_SERVER_LABEL, | |
) | |
with gr.Row(): | |
allowed_tools_box = gr.Textbox( | |
label="Allowed Tools (comma-separated)", | |
value=DEFAULT_ALLOWED_TOOLS, | |
) | |
require_approval_box = gr.Dropdown( | |
label="Require Approval", | |
choices=["never", "auto", "always"], | |
value="never", | |
) | |
with gr.Column(scale=3): | |
with gr.Tab("Answer (Markdown)"): | |
final_md = gr.Markdown(value="", elem_classes=["card", "markdown-wrap"]) | |
with gr.Tab("Live Summary (Markdown)"): | |
summary_md = gr.Markdown(value="", elem_classes=["card", "summary-wrap"]) | |
with gr.Tab("Event Log"): | |
log_md = gr.Markdown(value="", elem_classes=["card", "log-box"]) | |
run_btn.click( | |
stream_agent, | |
inputs=[question, api_key_box, model_box, server_url_box, server_label_box, require_approval_box, allowed_tools_box], | |
outputs=[final_md, summary_md, log_md], | |
) | |
return demo | |
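
# For reference: a hedged sketch (not part of the app) of how an OpenAI
# Responses API client could point at this Space's MCP endpoint once deployed.
# It mirrors the `tools` payload built in stream_agent above.
def example_client_call(api_key: str, question: str):
    client = OpenAI(api_key=api_key)
    return client.responses.create(
        model="gpt-5",
        input=question,
        tools=[{
            "type": "mcp",
            "server_label": DEFAULT_MCP_SERVER_LABEL,
            "server_url": DEFAULT_MCP_SERVER_URL,
            "allowed_tools": DEFAULT_ALLOWED_TOOLS.split(","),
            "require_approval": "never",
        }],
    )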

if __name__ == "__main__":
    demo = build_ui()
    # mcp_server=True exposes this app's MCP endpoint at /gradio_api/mcp/
    demo.queue(max_size=5).launch(debug=True, mcp_server=True)