Update app.py

app.py CHANGED
@@ -1,11 +1,12 @@
-# app.py —
+# app.py — ColPali + MCP (search-only) + GPT-5 follow-up responses
+# Images are injected by the app in new calls; no base64 is passed through MCP.
 
 import os
 import base64
 import tempfile
 from io import BytesIO
 from urllib.request import urlretrieve
-from typing import List, Tuple, Dict, Any
+from typing import List, Tuple, Dict, Any, Optional
 
 import gradio as gr
 from gradio_pdf import PDF
@@ -18,7 +19,7 @@ from tqdm import tqdm
 
 from colpali_engine.models import ColQwen2, ColQwen2Processor
 
-#
+# Streaming Responses API
 from openai import OpenAI
 
 
@@ -26,9 +27,10 @@ from openai import OpenAI
 # Globals & Config
 # =============================
 api_key_env = os.getenv("OPENAI_API_KEY", "").strip()
+
 ds: List[torch.Tensor] = []      # page embeddings
 images: List[Image.Image] = []   # PIL images in page order
-current_pdf_path: str
+current_pdf_path: Optional[str] = None
 
 device_map = (
     "cuda:0"
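Note: the `device_map` expression is cut off at this hunk boundary. A guard of this shape is typical (a sketch only — the diff does not show the app's actual fallback chain, and the `mps` branch is an assumption):

import torch

# Sketch: choose a device string for loading ColQwen2.
device_map = (
    "cuda:0"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)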
@@ -125,36 +127,13 @@ def index_from_url(url: str) -> Tuple[str, str]:
     return status, local_path
 
 
-def _build_image_parts_from_indices(indices: List[int]) -> List[Dict[str, Any]]:
-    """Turn page indices into OpenAI vision content parts."""
-    parts: List[Dict[str, Any]] = []
-    seen = sorted({i for i in indices if 0 <= i < len(images)})
-    for idx in seen:
-        b64 = encode_image_to_base64(images[idx])
-        parts.append({
-            "type": "input_image",
-            "image_url": f"data:image/jpeg;base64,{b64}",
-        })
-    return parts
-
 # =============================
-#
+# Local Search (ColPali)
 # =============================
 
 def search(query: str, k: int = 5) -> List[int]:
     """
     Search within an indexed PDF and return ONLY the indices of the most relevant pages (0-based).
-
-    MCP tool description:
-    - name: mcp_test_search
-    - description: Search within the indexed PDF for the most relevant pages and return their 0-based indices only.
-    - input_schema:
-        type: object
-        properties:
-          query: {type: string, description: "User query in natural language."}
-          k: {type: integer, minimum: 1, maximum: 50, default: 5, description: "Number of top pages to retrieve (before neighbor expansion)."}
-        required: ["query"]
-
     Returns:
         List[int]: Sorted unique 0-based indices of pages to inspect (includes neighbor expansion).
     """
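Note: `_build_image_parts_from_indices` (relocated by this change; see the later hunk) depends on an `encode_image_to_base64` helper that never appears in this diff. Given the `data:image/jpeg;base64,…` URL it is interpolated into, a minimal compatible sketch — the app's real version may differ, e.g. in JPEG quality settings:

import base64
from io import BytesIO
from PIL import Image

def encode_image_to_base64(img: Image.Image) -> str:
    """Serialize a PIL image as a base64-encoded JPEG payload (no data: prefix)."""
    buf = BytesIO()
    img.convert("RGB").save(buf, format="JPEG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")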
@@ -166,16 +145,14 @@ def search(query: str, k: int = 5) -> List[int]:
     k = max(1, min(int(k), len(images)))
     device = _ensure_model_device()
 
-    # Encode query
     with torch.no_grad():
         batch_query = processor.process_queries([query]).to(model.device)
         embeddings_query = model(**batch_query)
         q_vecs = list(torch.unbind(embeddings_query.to("cpu")))
 
-    # Score and select top-k
     scores = processor.score(q_vecs, ds, device=device)
     top_k_indices = scores[0].topk(k).indices.tolist()
-    print(query, top_k_indices)
+    print("[search]", query, top_k_indices)
 
     # Neighbor expansion for context
     base = set(top_k_indices)
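Note: the neighbor expansion below the top-k selection widens each hit to include the adjacent pages, so the agent sees surrounding context. The same logic as a standalone, runnable sketch (function name is ours, not the app's):

def expand_with_neighbors(hits: list, n_pages: int) -> list:
    """Add each hit's previous and next page, clamped to the document bounds."""
    expanded = set(hits)
    for i in hits:
        expanded.add(i - 1)
        expanded.add(i + 1)
    return sorted(i for i in expanded if 0 <= i < n_pages)

# Hits [0, 7] in a 10-page PDF pull in pages 1, 6 and 8 as well:
assert expand_with_neighbors([0, 7], 10) == [0, 1, 6, 7, 8]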
@@ -183,65 +160,91 @@ def search(query: str, k: int = 5) -> List[int]:
     for i in base:
         expanded.add(i - 1)
         expanded.add(i + 1)
-    expanded = {i for i in expanded if 0 <= i < len(images)}
+    expanded = {i for i in expanded if 0 <= i < len(images)}
 
     return sorted(expanded)
 
 
+def _build_image_parts_from_indices(indices: List[int]) -> List[Dict[str, Any]]:
+    """Turn page indices into OpenAI vision content parts."""
+    parts: List[Dict[str, Any]] = []
+    seen = sorted({i for i in indices if 0 <= i < len(images)})
+    for idx in seen:
+        b64 = encode_image_to_base64(images[idx])
+        parts.append({
+            "type": "input_image",
+            "image_url": f"data:image/jpeg;base64,{b64}",
+        })
+    return parts
+
+
 # =============================
-#
+# Agent System Prompt
 # =============================
 
 SYSTEM = (
     """
-You are a PDF research agent
-
-
-
-
-
-2) For each sub-query, call mcp_test_search to get indices; Once you receive the indices to use, print "Received" and stop generating. Images will be injected in your stream.
-3) Continue reasoning using ONLY the provided images. If info is insufficient, iterate: refine sub-queries and call the tools again. You may make further tool calls later in the conversation as needed.
-
-Grounding & citations:
-• Use ONLY information visible in the provided page images.
-• After any claim, cite as (p.<page>).
+You are a PDF research agent.
+
+Workflow:
+• When you need pages, call the tool: mcp_test_search(query: string, k: int).
+• The app will attach the images for the LAST search result you produced in this turn in a follow-up message.
+• Use ONLY the provided images for grounding and cite as (p.<page>).
 • If an answer is not present, say “Not found in the provided pages.”
 
-
-•
-• Do not refer to “the above” or “previous messages”.
+Deliverable:
+• Return a clear, standalone Markdown answer in the user's language. Include concise tables for lists of dates/items.
     """
 ).strip()
 
+
+# =============================
+# MCP config (search-only)
+# =============================
 DEFAULT_MCP_SERVER_URL = "https://manu-mcp-test.hf.space/gradio_api/mcp/"
 DEFAULT_MCP_SERVER_LABEL = "colpali_rag"
-DEFAULT_ALLOWED_TOOLS = "mcp_test_search
+DEFAULT_ALLOWED_TOOLS = "mcp_test_search"  # search-only; no get_pages
 
 
+# =============================
+# Streaming Agent (multi-round with previous_response_id)
+# =============================
+
 def stream_agent(question: str,
                  api_key: str,
-
+                 model_name: str,
                  server_url: str,
                  server_label: str,
                  require_approval: str,
                  allowed_tools: str):
     """
-
-
-
+    Multi-round streaming:
+      • Seed: optional local ColPali search on the user question to attach initial pages.
+      • Each round: open a GPT-5 stream with *attached images* (if any).
+      • If the model calls mcp_test_search and returns indices, we end the stream and
+        start a NEW API call with previous_response_id + the requested pages attached.
     """
-    final_text = "Answer:"
-    summary_text = "Reasoning:"
-    log_lines = ["Log"]
-
     if not api_key:
         yield "⚠️ **Please provide your OpenAI API key.**", "", ""
         return
 
+    if not images or not ds:
+        yield "⚠️ **Index a PDF first in tab 1.**", "", ""
+        return
+
     client = OpenAI(api_key=api_key)
 
+    # Optional seeding: attach some likely pages on round 1
+    try:
+        seed_indices = search(question, k=5) or []
+    except Exception as e:
+        yield f"❌ Search failed: {e}", "", ""
+        return
+
+    log_lines = ["Log", f"[seed] indices={seed_indices}"]
     prev_response_id: Optional[str] = None
+
+    # MCP tool routing (search-only)
     tools = [{
         "type": "mcp",
         "server_label": server_label or DEFAULT_MCP_SERVER_LABEL,
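Note: `DEFAULT_ALLOWED_TOOLS` is deliberately just `mcp_test_search` — the agent never fetches images over MCP. With the defaults above, the `tools` entry being assembled here (it is completed in the next hunk) should come out roughly like this (`allowed_tools` shown as a list; how the code splits the comma-separated textbox value is not visible in this diff):

tools = [{
    "type": "mcp",
    "server_label": "colpali_rag",
    "server_url": "https://manu-mcp-test.hf.space/gradio_api/mcp/",
    "allowed_tools": ["mcp_test_search"],  # search-only; images never cross MCP
    "require_approval": "never",
}]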
@@ -250,20 +253,34 @@ def stream_agent(question: str,
         "require_approval": require_approval or "never",
     }]
 
-    #
-
-
+    # Shared mutable state for each round
+    round_state: Dict[str, Any] = {
+        "last_search_indices": None,
+        "final_text": "",
+        "summary_text": "",
+    }
 
     def run_round(round_idx: int, attached_indices: List[int]):
+        """
+        Stream one round. If tool results (indices) arrive, store them in round_state["last_search_indices"].
+        """
         nonlocal prev_response_id
-        assembled_text = ""
-        assembled_summary = ""
-        # Will hold the most recent indices returned by mcp_test_search in THIS stream
-        last_search_indices: List[int] = []
 
-
-
+        round_state["last_search_indices"] = None
+        round_state["final_text"] = ""
+        round_state["summary_text"] = ""
+
+        # Build the user content for this round
+        parts: List[Dict[str, Any]] = []
+        if round_idx == 1:
+            parts.append({"type": "input_text", "text": question})
+        else:
+            parts.append({"type": "input_text", "text": "Continue reasoning with the newly attached pages."})
+
         parts += _build_image_parts_from_indices(attached_indices)
+        if attached_indices:
+            pages_str = ", ".join(str(i + 1) for i in sorted(set(attached_indices)))
+            parts.append({"type": "input_text", "text": f"(Attached pages: {pages_str}). Use ONLY these images; cite as (p.X)."})
 
         # First call includes system; follow-ups use previous_response_id
         if prev_response_id:
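Note: given the round logic above, the round-1 user content for `attached_indices=[2, 3]` would look like this (question text illustrative, base64 payloads truncated):

parts = [
    {"type": "input_text", "text": "What does the paper conclude?"},
    {"type": "input_image", "image_url": "data:image/jpeg;base64,/9j/..."},  # page index 2
    {"type": "input_image", "image_url": "data:image/jpeg;base64,/9j/..."},  # page index 3
    {"type": "input_text",
     "text": "(Attached pages: 3, 4). Use ONLY these images; cite as (p.X)."},
]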
@@ -279,21 +296,20 @@ def stream_agent(question: str,
             input=req_input,
             reasoning={"effort": "medium", "summary": "auto"},
             tools=tools,
-            store=True,
+            store=True,  # persist conversation state on server
         )
         if prev_response_id:
             req_kwargs["previous_response_id"] = prev_response_id
 
-        # Helper
+        # Helper: parse a JSON array of ints from tool result text
        def _maybe_parse_indices(chunk: str) -> List[int]:
            import json, re
-            # Find the last bracketed JSON array in the chunk
            arrs = re.findall(r'\[[^\]]*\]', chunk)
            for s in reversed(arrs):
                try:
                    val = json.loads(s)
                    if isinstance(val, list) and all(isinstance(x, int) for x in val):
-                        return sorted({x for x in val if
+                        return sorted({x for x in val if 0 <= x < len(images)})
                except Exception:
                    pass
            return []
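Note: `_maybe_parse_indices` grabs the last JSON int-array it can find in the buffered tool output and clamps it to valid page indices. The same logic as a runnable standalone sketch:

import json
import re

def parse_indices(chunk: str, n_pages: int) -> list:
    """Return the last valid JSON int-array in chunk, filtered to [0, n_pages)."""
    for s in reversed(re.findall(r'\[[^\]]*\]', chunk)):
        try:
            val = json.loads(s)
            if isinstance(val, list) and all(isinstance(x, int) for x in val):
                return sorted({x for x in val if 0 <= x < n_pages})
        except ValueError:
            pass
    return []

# Tool results arrive as text; the indices are buried inside:
assert parse_indices('{"result": "[3, 4, 99]"}', n_pages=10) == [3, 4]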
@@ -306,74 +322,74 @@ def stream_agent(question: str,
                 etype = getattr(event, "type", "")
 
                 if etype == "response.output_text.delta":
-
-                    yield
+                    round_state["final_text"] += event.delta
+                    yield round_state["final_text"] or " ", round_state["summary_text"] or " ", "\n".join(log_lines[-400:])
 
                 elif etype == "response.reasoning_summary_text.delta":
-
-                    yield
+                    round_state["summary_text"] += event.delta
+                    yield round_state["final_text"] or " ", round_state["summary_text"] or " ", "\n".join(log_lines[-400:])
 
-                #
+                # Log tool call argument deltas (optional)
                 elif etype in ("response.function_call_arguments.delta", "response.tool_call_arguments.delta"):
-
+                    delta = getattr(event, "delta", None)
+                    if delta:
+                        log_lines.append(str(delta))
 
-                #
+                # Capture tool RESULT text and try to parse indices
                 elif etype.startswith("response.tool_result"):
-
-
-                    if
-                        tool_result_buffer += str(
-                    # opportunistic parse so UI can progress early
+                    print("here")
+                    delta_text = getattr(event, "delta", "") or getattr(event, "output_text", "")
+                    if delta_text:
+                        tool_result_buffer += str(delta_text)
                     parsed_now = _maybe_parse_indices(tool_result_buffer)
                     if parsed_now:
-
-
-
-                        yield assembled_text or " ", assembled_summary or " ", "\n".join(log_lines[-400:])
+                        round_state["last_search_indices"] = parsed_now
+                        log_lines.append(f"[tool-result] indices={parsed_now}")
+                        yield round_state["final_text"] or " ", round_state["summary_text"] or " ", "\n".join(log_lines[-400:])
 
-            # Finalize
+            # Finalize this response; remember ID for follow-ups
             _final = stream.get_final_response()
             try:
                 prev_response_id = getattr(_final, "id", None)
             except Exception:
                 prev_response_id = None
 
-            #
-
-                return sorted(set(last_search_indices))
-
-            # Otherwise, just render whatever text we have
-            yield assembled_text or " ", assembled_summary or " ", "\n".join(log_lines[-400:])
-            return None
+            # Emit one last update after stream ends
+            yield round_state["final_text"] or " ", round_state["summary_text"] or " ", "\n".join(log_lines[-400:])
 
         except Exception as e:
             log_lines.append(f"[round {round_idx}] stream error: {e}")
-            yield f"❌ {e}",
-            return
+            yield f"❌ {e}", round_state["summary_text"] or "", "\n".join(log_lines[-400:])
+            return
 
-    # Controller: iterate rounds
+    # Controller: iterate rounds; if the model searched, attach those pages next
     max_rounds = 3
     round_idx = 1
+    pending_indices = list(seed_indices)
+
     while round_idx <= max_rounds:
-
-        next_indices = None
+        print(round_idx, pending_indices)
         for final_md, summary_md, log_md in run_round(round_idx, pending_indices):
             yield final_md, summary_md, log_md
 
-        # If the model
-
-        if
+        # If the model returned indices via the tool, use them in a fresh call
+        next_indices = round_state.get("last_search_indices") or []
+        if next_indices:
             pending_indices = next_indices
-            # Attach those pages in a **new** GPT-5 call using previous_response_id
             round_idx += 1
             continue
 
-        # No tool
+        # No further tool-driven retrieval → done
         break
 
     return
 
 
+
+# =============================
+# Gradio UI
+# =============================
+
 CUSTOM_CSS = """
 :root {
   --bg: #0e1117;
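Note: the controller chains rounds through the Responses API's `previous_response_id`, so earlier turns (including already-attached page images) never need to be resent. A minimal sketch of that mechanism, independent of this app (model name illustrative):

from openai import OpenAI

client = OpenAI()

first = client.responses.create(
    model="gpt-5",  # illustrative; the app takes the model name from the UI
    input="Summarize page 3.",
    store=True,     # server keeps the conversation state
)
follow_up = client.responses.create(
    model="gpt-5",
    input="Now compare it with page 4.",
    previous_response_id=first.id,  # continue from the stored response
)
print(follow_up.output_text)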
@@ -433,19 +449,20 @@ body {background: radial-gradient(1200px 600px at 20% -10%, rgba(124,58,237,.25)
 
 def build_ui():
     theme = gr.themes.Soft()
-    with gr.Blocks(title="ColPali PDF RAG +
+    with gr.Blocks(title="ColPali PDF RAG + Follow-up Responses", theme=theme, css=CUSTOM_CSS) as demo:
         gr.HTML(
             """
             <div class="app-header">
               <div class="icon">📚</div>
               <div>
-                <h1>ColPali PDF Search + Streaming Agent</h1>
-                <p>Index PDFs with ColQwen2
+                <h1>ColPali PDF Search + Streaming Agent (Follow-up Responses)</h1>
+                <p>Index PDFs with ColQwen2. The agent attaches images in follow-up GPT-5 calls; MCP is search-only.</p>
               </div>
             </div>
             """
         )
 
+        # ---- Tab 1: Index & Preview
         with gr.Tab("1) Index & Preview"):
             with gr.Row():
                 with gr.Column(scale=1):
@@ -482,19 +499,20 @@ def build_ui():
         index_btn.click(handle_upload, inputs=[pdf_input], outputs=[status_box, pdf_view])
         index_url_btn.click(handle_url, inputs=[url_box], outputs=[status_box, pdf_view])
 
+        # ---- Tab 2: Ask (Direct — returns indices)
         with gr.Tab("2) Ask (Direct — returns indices)"):
             with gr.Row():
                 with gr.Column(scale=1):
                     query_box = gr.Textbox(placeholder="Enter your question…", label="Query", lines=4)
                     k_slider = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results (k)", value=5)
                     search_button = gr.Button("🔍 Search", variant="primary")
-
-
+
                 with gr.Column(scale=2):
                     output_text = gr.Textbox(label="Indices (0-based)", lines=12, placeholder="[0, 1, 2, ...]")
 
             search_button.click(search, inputs=[query_box, k_slider], outputs=[output_text])
 
+        # ---- Tab 3: Agent (Streaming)
         with gr.Tab("3) Agent (Streaming)"):
             with gr.Row(equal_height=True):
                 with gr.Column(scale=1):
@@ -522,7 +540,7 @@ def build_ui():
                     )
                     with gr.Row():
                         server_url_box = gr.Textbox(
-                            label="MCP Server URL",
+                            label="MCP Server URL (search-only)",
                             value=DEFAULT_MCP_SERVER_URL,
                         )
                         server_label_box = gr.Textbox(
@@ -550,7 +568,15 @@ def build_ui():
 
         run_btn.click(
             stream_agent,
-            inputs=[
+            inputs=[
+                question,
+                api_key_box,
+                model_box,
+                server_url_box,
+                server_label_box,
+                require_approval_box,
+                allowed_tools_box,
+            ],
             outputs=[final_md, summary_md, log_md],
         )
 
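Note: `stream_agent` is a plain generator, which is all Gradio needs for streaming — each yielded 3-tuple refreshes the three bound output components. A minimal standalone illustration of the same pattern:

import time
import gradio as gr

def slow_count(n):
    # Every yield pushes a fresh value to the bound output component.
    text = ""
    for i in range(int(n)):
        text += f"step {i}\n"
        time.sleep(0.2)
        yield text

with gr.Blocks() as demo:
    n = gr.Number(value=5, label="Steps")
    out = gr.Textbox(label="Progress")
    gr.Button("Run").click(slow_count, inputs=[n], outputs=[out])

demo.launch()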
@@ -560,4 +586,5 @@ def build_ui():
 if __name__ == "__main__":
     demo = build_ui()
     # mcp_server=True exposes this app's MCP endpoint at /gradio_api/mcp/
+    # We keep the MCP server available, but the agent never uses MCP to pass images.
     demo.queue(max_size=5).launch(debug=True, mcp_server=True)
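Note: this is also why `search` keeps a docstring and type hints even after the MCP description block was removed from it — with `mcp_server=True`, Gradio derives the advertised tool schema from the function signature and docstring. Based on the description deleted above, the exposed tool should come out roughly as (a sketch, written as a Python dict):

mcp_test_search_schema = {
    "name": "mcp_test_search",
    "description": "Search within an indexed PDF and return ONLY the indices "
                   "of the most relevant pages (0-based).",
    "input_schema": {
        "type": "object",
        "properties": {
            "query": {"type": "string"},
            "k": {"type": "integer", "default": 5},
        },
        "required": ["query"],
    },
}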