"""Run a smolagents ``CodeAgent`` manager through a LangGraph state graph.

The module builds the model, the managed sub-agents and the manager agent at
import time, wires a one-node LangGraph graph around ``process_step`` and
exposes ``main`` as the synchronous entry point.
"""

import asyncio
import importlib
import importlib.resources  # ``import importlib`` alone does not load this submodule
import logging
import os
import time
import traceback
import uuid  # for generating thread IDs for checkpointer
from typing import AsyncIterator, Optional, TypedDict

import litellm
import yaml
from dotenv import find_dotenv, load_dotenv
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from smolagents import CodeAgent, LiteLLMModel
from smolagents.memory import ActionStep, FinalAnswerStep
from smolagents.monitoring import LogLevel

from agents import create_data_analysis_agent, create_media_agent, create_web_agent
from prompts import MANAGER_SYSTEM_PROMPT
from tools import perform_calculation, web_search
from utils import extract_final_answer

# NOTE(review): private litellm helper that enables verbose debug output on
# every import of this module -- confirm this is wanted outside development.
litellm._turn_on_debug()

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv(find_dotenv())

# Get required environment variables with validation
API_BASE = os.getenv("API_BASE")
API_KEY = os.getenv("API_KEY")
MODEL_ID = os.getenv("MODEL_ID")

if not all([API_BASE, API_KEY, MODEL_ID]):
    raise ValueError(
        "Missing required environment variables: API_BASE, API_KEY, MODEL_ID"
    )

# Pause inserted after every action step of the agent (seconds).
# NOTE(review): presumably a crude rate limit for the model API -- confirm.
STEP_DELAY_SECONDS = 60


# Define the state types for our graph
class AgentState(TypedDict):
    """State carried through the LangGraph graph."""

    task: str
    current_step: Optional[dict]  # Store serializable dict instead of ActionStep
    error: Optional[str]
    answer_text: Optional[str]


# Initialize model with error handling
try:
    model = LiteLLMModel(
        api_base=API_BASE,
        api_key=API_KEY,
        model_id=MODEL_ID,
    )
except Exception as e:
    logger.error("Failed to initialize model: %s", e)
    raise

web_agent = create_web_agent(model)
data_agent = create_data_analysis_agent(model)
media_agent = create_media_agent(model)

tools = [
    # DuckDuckGoSearchTool(max_results=3),
    # VisitWebpageTool(max_output_length=1000),
    web_search,
    perform_calculation,
]

# Initialize agent with error handling
try:
    prompt_templates = yaml.safe_load(
        importlib.resources.files("smolagents.prompts")
        .joinpath("code_agent.yaml")
        .read_text()
    )
    # prompt_templates["system_prompt"] = MANAGER_SYSTEM_PROMPT

    agent = CodeAgent(
        add_base_tools=True,
        additional_authorized_imports=[
            "json",
            "pandas",
            "numpy",
            "re",
        ],
        # max_steps=10,
        managed_agents=[web_agent, data_agent, media_agent],
        model=model,
        prompt_templates=prompt_templates,
        tools=tools,
        step_callbacks=None,
        verbosity_level=LogLevel.ERROR,
    )
    agent.logger.console.width = 66
    agent.visualize()

    # From here on ``tools`` refers to the agent's own tool collection.
    tools = agent.tools
    print(f"Tools: {tools}")
except Exception as e:
    logger.error("Failed to initialize agent: %s", e)
    raise


def _serialize_action_step(step: ActionStep) -> dict:
    """Copy the fields of an ``ActionStep`` that are reported into a plain dict."""
    return {
        "step_number": step.step_number,
        "model_output": step.model_output,
        "observations": step.observations,
        "tool_calls": [
            {"name": tc.name, "arguments": tc.arguments}
            for tc in (step.tool_calls or [])
        ],
        "action_output": step.action_output,
    }


async def process_step(state: AgentState) -> AgentState:
    """Run the agent on ``state["task"]`` and record its progress in ``state``.

    The previous ``current_step``, ``answer_text`` and ``error`` are cleared,
    then ``agent.run`` is consumed as a stream. Every ``ActionStep`` overwrites
    ``state["current_step"]`` and is followed by a pause of
    ``STEP_DELAY_SECONDS``; the first ``FinalAnswerStep`` stores its answer in
    ``state["answer_text"]`` and ends the call.

    Any exception is caught, logged and stored as a string in
    ``state["error"]``; the function itself does not raise.

    Returns:
        The same ``state`` mapping, updated in place.
    """
    try:
        # Clear previous step results before running agent.run
        state["current_step"] = None
        state["answer_text"] = None
        state["error"] = None

        steps = agent.run(
            task=state["task"],
            additional_args=None,
            images=None,
            # max_steps=1,  # Process one step at a time
            stream=True,
            reset=False,  # Maintain agent's internal state across process_step calls
        )

        for step in steps:
            if isinstance(step, ActionStep):
                state["current_step"] = _serialize_action_step(step)
                logger.info("Processed action step %s", step.step_number)
                logger.info("Step %s details: %s", step.step_number, step)
                logger.info("Sleeping for %s seconds...", STEP_DELAY_SECONDS)
                # asyncio.sleep instead of time.sleep: a blocking sleep inside
                # a coroutine stalls the whole event loop for the duration.
                await asyncio.sleep(STEP_DELAY_SECONDS)
            elif isinstance(step, FinalAnswerStep):
                state["answer_text"] = step.final_answer
                logger.info("Processed final answer")
                logger.debug("Final answer details: %s", step)
                logger.info("Extracted answer text: %s", state["answer_text"])
                # Return immediately when we get a final answer
                return state

        # If loop finishes without FinalAnswerStep, return current state
        return state
    except Exception as e:
        state["error"] = str(e)
        logger.error("Error during agent execution step: %s", e)
        return state


def should_continue(state: AgentState) -> bool:
    """Return True while the state holds neither an answer nor an error."""
    has_answer = state.get("answer_text") is not None
    has_error = state.get("error") is not None
    continue_execution = not has_answer and not has_error
    logger.debug(
        "Checking should_continue: answer_text=%s, error=%s -> Continue=%s",
        has_answer,
        has_error,
        continue_execution,
    )
    return continue_execution


# Build the LangGraph graph once with persistence
memory = MemorySaver()
builder = StateGraph(AgentState)
builder.add_node("process_step", process_step)
builder.add_edge(START, "process_step")
builder.add_conditional_edges(
    "process_step", should_continue, {True: "process_step", False: END}
)
graph = builder.compile(checkpointer=memory)


def _unwrap_node_state(chunk: Optional[dict]) -> Optional[dict]:
    """Return the agent state contained in one streamed chunk.

    LangGraph streams dicts where keys are node names and values are state
    dicts. A chunk that already looks like a flat ``AgentState`` (it has a
    ``"task"`` key) is returned unchanged. ``None`` is returned when no state
    dict can be found.
    """
    if not chunk:
        return None
    if "task" in chunk:
        return chunk
    inner = next(iter(chunk.values()))
    return inner if isinstance(inner, dict) else None


async def stream_execution(task: str, thread_id: str) -> AsyncIterator[dict]:
    """Stream the execution of the agent.

    Args:
        task: The task text handed to the agent; must be non-empty.
        thread_id: Identifier passed to the checkpointer via the run config.

    Yields:
        Every chunk produced by ``graph.astream``, unmodified.

    Raises:
        ValueError: If ``task`` is empty.
        RuntimeError: After yielding a chunk whose state carries an error but
            no answer.
    """
    if not task:
        raise ValueError("Task cannot be empty")

    logger.info("Initializing agent execution for task: %s", task)

    # Initialize the state
    initial_state: AgentState = {
        "task": task,
        "current_step": None,
        "error": None,
        "answer_text": None,
    }

    # Pass thread_id via the config dict so the checkpointer can persist state
    async for chunk in graph.astream(
        initial_state, {"configurable": {"thread_id": thread_id}}
    ):
        yield chunk
        # The chunk is keyed by node name, so the error has to be looked up on
        # the unwrapped state; checking the chunk itself never finds it.
        node_state = _unwrap_node_state(chunk)
        # Propagate error immediately if it occurs without an answer
        if node_state and node_state.get("error") and not node_state.get("answer_text"):
            logger.error("Propagating error from stream: %s", node_state["error"])
            raise RuntimeError(node_state["error"])


def _record_step(steps: list, step: dict) -> bool:
    """Append ``step`` unless it repeats the last recorded step number.

    Returns True when the step was appended.
    """
    if steps and steps[-1]["step_number"] == step["step_number"]:
        return False
    steps.append(step)
    return True


def _print_step(step: dict) -> None:
    """Print one serialized step for direct user feedback during streaming."""
    print(f"\nStep {step['step_number']}:")
    print(f"Model Output: {step['model_output']}")
    print(f"Observations: {step['observations']}")
    if step.get("tool_calls"):
        print("Tool Calls:")
        for tc in step["tool_calls"]:
            print(f"  - {tc['name']}: {tc['arguments']}")
    if step.get("action_output"):
        print(f"Action Output: {step['action_output']}")


async def run_with_streaming(task: str, thread_id: str) -> dict:
    """Run the agent with streaming output and return the results.

    Args:
        task: The task text handed to the agent.
        thread_id: Identifier passed to the checkpointer.

    Returns:
        A dict with the keys ``"steps"`` (list of serialized steps),
        ``"final_answer"`` and ``"error"``. Exceptions are caught and reported
        through ``"error"`` rather than raised.
    """
    last_chunk = None
    steps: list = []
    error = None
    final_answer_text = None

    try:
        logger.info("Starting execution run for task: %s", task)
        async for chunk in stream_execution(task, thread_id):
            last_chunk = chunk
            # Unwrap the node-keyed chunk before reading state fields.
            node_state = _unwrap_node_state(chunk)
            current_step = node_state.get("current_step") if node_state else None
            if current_step and _record_step(steps, current_step):
                _print_step(current_step)

        # After the stream is finished, process the last state
        logger.info("Stream finished.")
        if last_chunk:
            actual_state = _unwrap_node_state(last_chunk)
            if actual_state:
                final_answer_text = actual_state.get("answer_text")
                error = actual_state.get("error")
                logger.info(
                    "Final answer text extracted from last state: %s",
                    final_answer_text,
                )
                logger.info("Error extracted from last state: %s", error)
                # Ensure steps list is consistent with the final state if needed
                last_step_in_state = actual_state.get("current_step")
                if last_step_in_state and _record_step(steps, last_step_in_state):
                    logger.debug("Adding last step from final state to steps list.")
            else:
                logger.warning(
                    "Could not find actual state dictionary within last_state."
                )

        return {"steps": steps, "final_answer": final_answer_text, "error": error}

    except Exception as e:
        logger.error(
            "Exception during run_with_streaming: %s\n%s", e, traceback.format_exc()
        )
        # Attempt to return based on the last known state even if the
        # exception occurred outside the stream
        actual_state = _unwrap_node_state(last_chunk)
        final_answer_text = actual_state.get("answer_text") if actual_state else None
        return {"steps": steps, "final_answer": final_answer_text, "error": str(e)}


def main(task: str, thread_id: Optional[str] = None):
    """Wrap ``task`` in GAIA instructions, run the agent and return its answer.

    Args:
        task: The question to solve.
        thread_id: Checkpointer thread identifier. A fresh UUID is generated
            per call when omitted; a default evaluated in the signature would
            be created once and shared by every call.

    Returns:
        Whatever ``extract_final_answer`` returns for the run result.
    """
    if thread_id is None:
        thread_id = str(uuid.uuid4())

    # Enhance the question with instructions specific to GAIA tasks
    enhanced_question = f"""
    GAIA Benchmark Question: {task}

    This is a multi-step reasoning problem from the GAIA benchmark. Please solve it by:
    1. Breaking the question down into clear logical steps
    2. Using the appropriate specialized agents when needed:
       - web_agent for web searches and browsing
       - data_agent for data analysis and calculations
       - media_agent for working with images and PDFs
    3. Tracking your progress through the problem
    4. Providing your final answer in EXACTLY the format requested by the question

    IMPORTANT: GAIA questions often involve multiple steps of information gathering and reasoning.
    You must follow the chain of reasoning completely and provide the exact format requested.
    """

    logger.info(
        "Starting agent run from __main__ for task: '%s' with thread_id: %s",
        task,
        thread_id,
    )
    result = asyncio.run(run_with_streaming(enhanced_question, thread_id))
    logger.info("Agent run finished.")

    logger.info("Result: %s", result)
    return extract_final_answer(result)


if __name__ == "__main__":
    # Example Usage
    task_to_run = "What is the capital of France?"

    thread_id = str(uuid.uuid4())  # Generate a unique thread ID for this run

    final_answer = main(task_to_run, thread_id)
    print(f"Final Answer: {final_answer}")