oremaz committed · Commit 9da53da · Parent(s): 9bba232
Update agent.py

agent.py CHANGED
@@ -21,6 +21,7 @@ from llama_index.core.query_engine import RetrieverQueryEngine
 from llama_index.core.retrievers import VectorIndexRetriever
 from llama_index.core.tools import FunctionTool
 from llama_index.core.workflow import Context

 # LlamaIndex specialized imports
 from llama_index.callbacks.wandb import WandbCallbackHandler
@@ -40,7 +41,6 @@ from llama_index.readers.file import (
     DocxReader,
     CSVReader,
     PandasExcelReader,
-    ImageReader,
 )
 from typing import List, Union
 from llama_index.core import VectorStoreIndex, Document, Settings
@@ -50,6 +50,8 @@ from llama_index.core.postprocessor import SentenceTransformerRerank
 from llama_index.core.query_engine import RetrieverQueryEngine
 from llama_index.core.query_pipeline import QueryPipeline

 wandb_callback = WandbCallbackHandler(run_args={"project": "gaia-llamaindex-agents"})
 llama_debug = LlamaDebugHandler(print_trace_on_end=True)
@@ -113,68 +115,66 @@ Settings.callback_manager = callback_manager

 def read_and_parse_content(input_path: str) -> List[Document]:
     """
-    Reads and parses content from a file path
     """
-    #
     readers_map = {
-        # Documents
         '.pdf': PDFReader(),
         '.docx': DocxReader(),
         '.doc': DocxReader(),
-        # Data files
         '.csv': CSVReader(),
         '.json': JSONReader(),
         '.xlsx': PandasExcelReader(),
-        # Audio files - special handling
-        # '.mp3': handled separately
     }

-    documents
-            documents = loader.load_data(urls=[input_path])
-
-    # --- Local File Handling ---
-    else:
-        if not os.path.exists(input_path):
-            return [Document(text=f"Error: File not found at {input_path}")]
-
-        file_extension = os.path.splitext(input_path)[1].lower()
-
-        if file_extension in ['.mp3', '.mp4', '.wav', '.m4a', '.flac']:
-            try:
-                loader = AssemblyAIAudioTranscriptReader(file_path=input_path)
-                documents = loader.load_data()
-                return documents
-            except Exception as e:
-                return [Document(text=f"Error transcribing audio: {e}")]

-
             return [Document(
-                text=f"
-                metadata={
             )]

-
-    #
     for doc in documents:
         doc.metadata["source"] = input_path
-
     return documents

 # --- Create the final LlamaIndex Tool from the completed function ---
@@ -184,150 +184,183 @@ extract_url_tool = FunctionTool.from_defaults(
     description="Searches web and returns a relevant URL based on a query"
 )

-    """
-    Creates a RAG query engine tool from documents with advanced indexing and querying capabilities.
-
-    This function implements a sophisticated RAG pipeline using hierarchical or sentence-window parsing
-    depending on document count, vector indexing, and reranking for optimal information retrieval.
-
-    Args:
-        documents (List[Document]): A list of LlamaIndex Document objects from read_and_parse_tool.
-            Must not be empty to create a valid RAG engine.
-        query (str, optional): If provided, immediately queries the created RAG engine and returns
-            the answer as a string. If None, returns the QueryEngineTool for later use.
-            Defaults to None.
-
-    Returns:
-        Union[QueryEngineTool, str]:
-            - QueryEngineTool: When query=None, returns a tool configured for agent use with
-              advanced reranking and similarity search capabilities.
-            - str: When query is provided, returns the direct answer from the RAG engine.
-            - None: When documents list is empty.

-    # --- 1. Node Parsing (from your 'create_advanced_index' logic) ---
-    # Using the exact parsers and logic you defined.
-    hierarchical_parser = HierarchicalNodeParser.from_defaults(chunk_sizes=[2048, 512, 128])
-    sentence_window_parser = SentenceWindowNodeParser.from_defaults(
-        window_size=3,
-        window_metadata_key="window",
-        original_text_metadata_key="original_text",
-    )
-
-    # Choose parser based on document count
-    if len(documents) > 5:  # Heuristic for using hierarchical parser
-        nodes = hierarchical_parser.get_nodes_from_documents(documents)
-    else:
-        nodes = sentence_window_parser.get_nodes_from_documents(documents)
-
-    # --- 2. Index Creation ---
-    # Assumes Settings.embed_model is configured globally as in your snippet
-    index = VectorStoreIndex(nodes)
-
-    # --- 3. Query Engine Creation (from your 'create_context_aware_query_engine' logic) ---
-    # Using the exact reranker you specified
-    reranker = SentenceTransformerRerank(
-        model="cross-encoder/ms-marco-MiniLM-L-2-v2",
-        top_n=5
-    )

         )
-    )
-
-    if query:
-        result = rag_engine_tool.query_engine.query(query)
-        return str(result)

-        docs.append(read_and_parse_content(path))
-    return create_rag_tool_fn(docs, query)

-)

 # 1. Create the base DuckDuckGo search tool from the official spec.
 # This tool returns text summaries of search results, not just URLs.
 base_duckduckgo_tool = DuckDuckGoSearchToolSpec().to_tool_list()[1]

-def
     """
-
-    Args:
-        query: The natural language search query.
-    Returns:
-        A string containing the first URL found, or an error message if none is found.
     """
-    #
-    search_results = base_duckduckgo_tool(query, max_results
-    print(search_results)
-
-    # Use a regular expression to find the first URL in the text output
-    # The \S+ pattern matches any sequence of non-whitespace characters
     url_match = re.search(r"https?://\S+", str(search_results))

-    if url_match:
-        return

-    verbose=True,
-    callback_manager=callback_manager,
-)

-# 3. Create the final, customized FunctionTool for the agent.
-# This is the tool you will actually give to your agent.
-extract_url_tool = FunctionTool.from_defaults(
-    fn=search_and_extract_top_url,
-    name="extract_url_tool",
-    description=(
-        "Use this tool when you need to find a relevant URL to answer a question. It takes a search query as input and returns a single, relevant URL."
-    )
-)

 def safe_import(module_name):
     """Safely import a module, return None if not available"""
@@ -422,17 +455,6 @@ code_execution_tool = FunctionTool.from_defaults(
     description="Executes Python code safely for calculations and data processing"
 )

-code_agent = ReActAgent(
-    name="code_agent",
-    description="Handles Python code for calculations and data processing",
-    system_prompt="You are a Python programming specialist. You work with Python code to perform calculations, data analysis, and mathematical operations.",
-    tools=[code_execution_tool],
-    llm=code_llm,
-    max_steps=6,
-    verbose=True,
-    callback_manager=callback_manager,
-)
-
 def clean_response(response: str) -> str:
     """Clean response by removing common prefixes"""
     response_clean = response.strip()
@@ -528,9 +550,43 @@ class EnhancedGAIAAgent:
         if not hf_token:
             print("Warning: HUGGINGFACEHUB_API_TOKEN not found, some features may not work")

-
-
-

     def download_gaia_file(self, task_id: str, api_url: str = "https://agents-course-unit4-scoring.hf.space") -> str:
         """Download file associated with task_id"""
@@ -546,32 +602,56 @@ class EnhancedGAIAAgent:
             print(f"Failed to download file for task {task_id}: {e}")
             return None

     async def solve_gaia_question(self, question_data: Dict[str, Any]) -> str:
         """
-        Solve GAIA question with
         """
         question = question_data.get("Question", "")
         task_id = question_data.get("task_id", "")

-        # Try to download file if task_id provided
         file_path = None
         if task_id:
             try:
                 file_path = self.download_gaia_file(task_id)
                 if file_path:
-                    documents
             except Exception as e:
                 print(f"Failed to download/process file for task {task_id}: {e}")

-        #
         context_prompt = f"""
 GAIA Task ID: {task_id}
 Question: {question}
-{f'File
-
         try:
             ctx = Context(self.coordinator)
             print("=== AGENT REASONING STEPS ===")

             handler = self.coordinator.run(ctx=ctx, user_msg=context_prompt)

@@ -588,9 +668,18 @@ You are a general AI assistant. I will ask you a question. Report your thoughts,
             final_answer = str(final_response).strip()

             print(f"Final GAIA formatted answer: {final_answer}")
             return final_answer

         except Exception as e:
             error_msg = f"Error processing question: {str(e)}"
             print(error_msg)
-            return error_msg
@@ -21,6 +21,7 @@ from llama_index.core.query_engine import RetrieverQueryEngine
 from llama_index.core.retrievers import VectorIndexRetriever
 from llama_index.core.tools import FunctionTool
 from llama_index.core.workflow import Context
+from llama_index.postprocessor.colpali_rerank import ColPaliRerank

 # LlamaIndex specialized imports
 from llama_index.callbacks.wandb import WandbCallbackHandler
@@ -40,7 +41,6 @@ from llama_index.readers.file import (
     DocxReader,
     CSVReader,
     PandasExcelReader,
 )
 from typing import List, Union
 from llama_index.core import VectorStoreIndex, Document, Settings
@@ -50,6 +50,8 @@ from llama_index.core.postprocessor import SentenceTransformerRerank
 from llama_index.core.query_engine import RetrieverQueryEngine
 from llama_index.core.query_pipeline import QueryPipeline

+import importlib.util
+import sys

 wandb_callback = WandbCallbackHandler(run_args={"project": "gaia-llamaindex-agents"})
 llama_debug = LlamaDebugHandler(print_trace_on_end=True)
@@ -113,68 +115,66 @@ Settings.callback_manager = callback_manager

 def read_and_parse_content(input_path: str) -> List[Document]:
     """
+    Reads and parses content from a local file path into Document objects.
+    URL handling has been moved to search_and_extract_top_url.
     """
+    # Remove URL handling - this will now only handle local files
+    if not os.path.exists(input_path):
+        return [Document(text=f"Error: File not found at {input_path}")]
+
+    file_extension = os.path.splitext(input_path)[1].lower()
+
+    # Readers map
     readers_map = {
         '.pdf': PDFReader(),
         '.docx': DocxReader(),
         '.doc': DocxReader(),
         '.csv': CSVReader(),
         '.json': JSONReader(),
         '.xlsx': PandasExcelReader(),
     }

+    if file_extension in ['.mp3', '.mp4', '.wav', '.m4a', '.flac']:
+        try:
+            loader = AssemblyAIAudioTranscriptReader(file_path=input_path)
+            documents = loader.load_data()
+            return documents
+        except Exception as e:
+            return [Document(text=f"Error transcribing audio: {e}")]

+    if file_extension in ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp']:
+        # Load the actual image content, not just the path
+        try:
+            with open(input_path, 'rb') as f:
+                image_data = f.read()
             return [Document(
+                text="IMAGE_CONTENT_BINARY",
+                metadata={
+                    "source": input_path,
+                    "type": "image",
+                    "path": input_path,
+                    "image_data": image_data  # Store actual image data
+                }
             )]
+        except Exception as e:
+            return [Document(text=f"Error reading image: {e}")]

+    if file_extension in readers_map:
+        loader = readers_map[file_extension]
+        documents = loader.load_data(file=input_path)
+    else:
+        # Fallback for text files
+        try:
+            with open(input_path, 'r', encoding='utf-8') as f:
+                content = f.read()
+            documents = [Document(text=content, metadata={"source": input_path})]
+        except Exception as e:
+            return [Document(text=f"Error reading file as plain text: {e}")]
+
+    # Add source metadata
     for doc in documents:
         doc.metadata["source"] = input_path
+
     return documents

 # --- Create the final LlamaIndex Tool from the completed function ---
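A minimal usage sketch of the rewritten reader dispatch above (the file names are illustrative, not from the commit): each extension routes to the audio, image, reader-map, or plain-text branch of read_and_parse_content.

    # Hypothetical local files; prints how many Documents each path yields.
    for path in ["report.pdf", "chart.png", "interview.mp3", "notes.txt"]:
        docs = read_and_parse_content(path)
        print(path, "->", len(docs), "document(s)", docs[0].metadata.get("type", "text"))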
@@ -184,150 +184,183 @@ extract_url_tool = FunctionTool.from_defaults(
     description="Searches web and returns a relevant URL based on a query"
 )

+class DynamicQueryEngineManager:
+    """Single unified manager for all RAG operations - replaces the entire static approach."""

+    def __init__(self, initial_documents: List[str] = None):
+        self.documents = []
+        self.query_engine_tool = None

+        # Load initial documents if provided
+        if initial_documents:
+            self._load_initial_documents(initial_documents)
+
+        self._create_rag_tool()

+    def _load_initial_documents(self, document_paths: List[str]):
+        """Load initial documents using read_and_parse_content."""
+        for path in document_paths:
+            docs = read_and_parse_content(path)
+            self.documents.extend(docs)
+        print(f"Loaded {len(self.documents)} initial documents")

+    def _create_rag_tool(self):
+        """Create RAG tool using your sophisticated logic."""
+        documents = self.documents if self.documents else [
+            Document(text="No documents loaded yet. Use web search to add content.")
+        ]
+
+        # Your exact sophisticated RAG logic from create_rag_tool_fn
+        hierarchical_parser = HierarchicalNodeParser.from_defaults(chunk_sizes=[2048, 512, 128])
+        sentence_window_parser = SentenceWindowNodeParser.from_defaults(
+            window_size=3,
+            window_metadata_key="window",
+            original_text_metadata_key="original_text",
+        )
+
+        if len(documents) > 5:
+            nodes = hierarchical_parser.get_nodes_from_documents(documents)
+        else:
+            nodes = sentence_window_parser.get_nodes_from_documents(documents)
+
+        index = VectorStoreIndex(nodes)
+
+        # Your HybridReranker class (exact same implementation)
+        class HybridReranker:
+            def __init__(self):
+                self.text_reranker = SentenceTransformerRerank(
+                    model="cross-encoder/ms-marco-MiniLM-L-2-v2",
+                    top_n=3
+                )
+                self.visual_reranker = ColPaliRerank(
+                    top_n=3,
+                    model_name="vidore/colpali-v1.2",
+                    device="cuda"
+                )
+
+            def postprocess_nodes(self, nodes, query_bundle):
+                # Your exact implementation
+                text_nodes = []
+                visual_nodes = []
+
+                for node in nodes:
+                    if (hasattr(node, 'image_path') and node.image_path) or \
+                       (hasattr(node, 'metadata') and node.metadata.get('file_type') in ['jpg', 'png', 'jpeg', 'pdf']) or \
+                       (hasattr(node, 'metadata') and node.metadata.get('type') in ['image', 'web_image']):
+                        visual_nodes.append(node)
+                    else:
+                        text_nodes.append(node)
+
+                reranked_text = []
+                reranked_visual = []
+
+                if text_nodes:
+                    reranked_text = self.text_reranker.postprocess_nodes(text_nodes, query_bundle)
+
+                if visual_nodes:
+                    reranked_visual = self.visual_reranker.postprocess_nodes(visual_nodes, query_bundle)
+
+                combined_results = []
+                max_len = max(len(reranked_text), len(reranked_visual))
+
+                for i in range(max_len):
+                    if i < len(reranked_text):
+                        combined_results.append(reranked_text[i])
+                    if i < len(reranked_visual):
+                        combined_results.append(reranked_visual[i])
+
+                return combined_results[:5]
+
+        hybrid_reranker = HybridReranker()
+
+        query_engine = index.as_query_engine(
+            similarity_top_k=10,
+            node_postprocessors=[hybrid_reranker],
+        )
+
+        self.query_engine_tool = QueryEngineTool.from_defaults(
+            query_engine=query_engine,
+            name="dynamic_hybrid_multimodal_rag_tool",
+            description=(
+                "Advanced dynamic knowledge base with hybrid reranking. "
+                "Uses ColPali for visual content and SentenceTransformer for text content. "
+                "Automatically updated with web search content."
+            )
         )

+    def add_documents(self, new_documents: List[Document]):
+        """Add documents from web search and recreate tool."""
+        self.documents.extend(new_documents)
+        self._create_rag_tool()  # Recreate with ALL documents
+        print(f"Added {len(new_documents)} documents. Total: {len(self.documents)}")

+    def get_tool(self):
+        return self.query_engine_tool
+
+# Global instance
+dynamic_qe_manager = DynamicQueryEngineManager()

 # 1. Create the base DuckDuckGo search tool from the official spec.
 # This tool returns text summaries of search results, not just URLs.
 base_duckduckgo_tool = DuckDuckGoSearchToolSpec().to_tool_list()[1]

+
+def search_and_extract_content_from_url(query: str) -> List[Document]:
     """
+    Searches web, gets top URL, and extracts both text content and images.
+    Returns a list of Document objects containing the extracted content.
     """
+    # Get URL from search
+    search_results = base_duckduckgo_tool(query, max_results=1)
     url_match = re.search(r"https?://\S+", str(search_results))

+    if not url_match:
+        return [Document(text="No URL could be extracted from the search results.")]
+
+    url = url_match.group(0)[:-2]
+    documents = []
+
+    try:
+        # Check if it's a YouTube URL
+        if "youtube" in urlparse(url).netloc:
+            loader = YoutubeTranscriptReader()
+            documents = loader.load_data(youtubelinks=[url])
+        else:
+            loader = TrafilaturaWebReader(include_images=True)
+            documents = loader.load_data(urls=[url])


+def enhanced_web_search_and_update(query: str) -> str:
+    """
+    Performs web search, extracts content, and adds it to the dynamic query engine.
+    """
+    # Extract content from web search
+    documents = search_and_extract_content_from_url(query)
+
+    # Add documents to the dynamic query engine
+    if documents and not any("Error" in doc.text for doc in documents):
+        dynamic_qe_manager.add_documents(documents)
+
+        # Return summary of what was added
+        text_docs = [doc for doc in documents if doc.metadata.get("type") == "web_text"]
+        image_docs = [doc for doc in documents if doc.metadata.get("type") == "web_image"]
+
+        summary = f"Successfully added web content to knowledge base:\n"
+        summary += f"- {len(text_docs)} text documents\n"
+        summary += f"- {len(image_docs)} images\n"
+        summary += f"Source: {documents[0].metadata.get('source', 'Unknown')}"
+
+        return summary
+    else:
+        error_msg = documents[0].text if documents else "No content extracted"
+        return f"Failed to extract web content: {error_msg}"
+
+# Create the enhanced web search tool
+enhanced_web_search_tool = FunctionTool.from_defaults(
+    fn=enhanced_web_search_and_update,
+    name="enhanced_web_search",
+    description="Search the web, extract content and images, and add them to the knowledge base for future queries."
+)

 def safe_import(module_name):
     """Safely import a module, return None if not available"""
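As a rough sketch of how the pieces above are meant to interact (not part of the commit): enhanced_web_search_and_update grows the shared knowledge base, and the rebuilt engine is then queried through the manager's tool; the old removed code already relied on QueryEngineTool exposing the engine via .query_engine.

    # Illustrative only: search first, then query the refreshed engine.
    print(enhanced_web_search_and_update("LlamaIndex sentence window parsing"))
    tool = dynamic_qe_manager.get_tool()
    answer = tool.query_engine.query("Summarize the page that was just added.")
    print(str(answer))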
@@ -422,17 +455,6 @@ code_execution_tool = FunctionTool.from_defaults(
     description="Executes Python code safely for calculations and data processing"
 )

 def clean_response(response: str) -> str:
     """Clean response by removing common prefixes"""
     response_clean = response.strip()
@@ -528,9 +550,43 @@ class EnhancedGAIAAgent:
         if not hf_token:
             print("Warning: HUGGINGFACEHUB_API_TOKEN not found, some features may not work")

+        # Initialize the dynamic query engine manager
+        self.dynamic_qe_manager = DynamicQueryEngineManager()
+
+        # Create enhanced agents with dynamic tools
+        self.external_knowledge_agent = ReActAgent(
+            name="external_knowledge_agent",
+            description="Advanced information retrieval with dynamic knowledge base",
+            system_prompt="""You are an advanced information specialist with a sophisticated RAG system.
+            Your knowledge base uses hybrid reranking and grows dynamically with each web search and document addition.
+            Always add relevant content to your knowledge base, then query it for answers.""",
+            tools=[
+                enhanced_web_search_tool,
+                self.dynamic_qe_manager.get_tool(),
+                code_execution_tool
+            ],
+            llm=proj_llm,
+            max_steps=8,
+            verbose=True,
+            callback_manager=callback_manager,
+        )
+
+        self.code_agent = ReActAgent(
+            name="code_agent",
+            description="Handles Python code for calculations and data processing",
+            system_prompt="You are a Python programming specialist. You work with Python code to perform calculations, data analysis, and mathematical operations.",
+            tools=[code_execution_tool],
+            llm=code_llm,
+            max_steps=6,
+            verbose=True,
+            callback_manager=callback_manager,
+        )
+
+        # Fixed indentation: coordinator initialization inside __init__
+        self.coordinator = AgentWorkflow(
+            agents=[self.external_knowledge_agent, self.code_agent],
+            root_agent="external_knowledge_agent"
+        )

     def download_gaia_file(self, task_id: str, api_url: str = "https://agents-course-unit4-scoring.hf.space") -> str:
         """Download file associated with task_id"""
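A condensed sketch of the wiring this __init__ sets up, assuming EnhancedGAIAAgent can be constructed without arguments (the diff does not show the full constructor): the workflow is rooted at external_knowledge_agent, which can hand off to code_agent for pure computation.

    # Illustrative only; mirrors the run call in solve_gaia_question below.
    agent = EnhancedGAIAAgent()
    ctx = Context(agent.coordinator)
    handler = agent.coordinator.run(ctx=ctx, user_msg="What is 17 * 23?")
    # final_response = await handler   # awaited inside an async caller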
@@ -546,32 +602,56 @@ class EnhancedGAIAAgent:
             print(f"Failed to download file for task {task_id}: {e}")
             return None

+    def add_documents_to_knowledge_base(self, file_path: str):
+        """Add downloaded GAIA documents to the dynamic knowledge base"""
+        try:
+            documents = read_and_parse_content(file_path)
+            if documents:
+                self.dynamic_qe_manager.add_documents(documents)
+                print(f"Added {len(documents)} documents from {file_path} to dynamic knowledge base")
+
+                # Update the agent's tools with the refreshed query engine
+                self.external_knowledge_agent.tools = [
+                    enhanced_web_search_tool,
+                    self.dynamic_qe_manager.get_tool(),  # Get the updated tool
+                    code_execution_tool
+                ]
+                return True
+        except Exception as e:
+            print(f"Failed to add documents from {file_path}: {e}")
+            return False
+
     async def solve_gaia_question(self, question_data: Dict[str, Any]) -> str:
         """
+        Solve GAIA question with dynamic knowledge base integration
         """
         question = question_data.get("Question", "")
         task_id = question_data.get("task_id", "")

+        # Try to download and add file to knowledge base if task_id provided
         file_path = None
         if task_id:
             try:
                 file_path = self.download_gaia_file(task_id)
                 if file_path:
+                    # Add documents to dynamic knowledge base
+                    self.add_documents_to_knowledge_base(file_path)
+                    print(f"Successfully integrated GAIA file into dynamic knowledge base")
             except Exception as e:
                 print(f"Failed to download/process file for task {task_id}: {e}")

+        # Enhanced context prompt with dynamic knowledge base awareness
         context_prompt = f"""
 GAIA Task ID: {task_id}
 Question: {question}
+{f'File processed and added to knowledge base: {file_path}' if file_path else 'No additional files'}
+
+You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""
+
         try:
             ctx = Context(self.coordinator)
             print("=== AGENT REASONING STEPS ===")
+            print(f"Dynamic knowledge base contains {len(self.dynamic_qe_manager.documents)} documents")

             handler = self.coordinator.run(ctx=ctx, user_msg=context_prompt)

@@ -588,9 +668,18 @@ You are a general AI assistant. I will ask you a question. Report your thoughts,
             final_answer = str(final_response).strip()

             print(f"Final GAIA formatted answer: {final_answer}")
+            print(f"Knowledge base now contains {len(self.dynamic_qe_manager.documents)} documents")
+
             return final_answer

         except Exception as e:
             error_msg = f"Error processing question: {str(e)}"
             print(error_msg)
+            return error_msg
+
+    def get_knowledge_base_stats(self):
+        """Get statistics about the current knowledge base"""
+        return {
+            "total_documents": len(self.dynamic_qe_manager.documents),
+            "document_sources": [doc.metadata.get("source", "Unknown") for doc in self.dynamic_qe_manager.documents]
+        }