"""LangGraph Agent for GAIA Assessment"""

import os
from typing import List, Dict, Any

from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.document_loaders import ArxivLoader

load_dotenv()


class GAIAAgent:
    """Agent for the GAIA assessment.

    Wires a tool-calling chat model into a two-node LangGraph graph
    ("assistant" and "tools") and exposes :meth:`run` to answer one question.
    """

    # Marker the system prompt asks the model to put before its final answer.
    _FINAL_ANSWER_MARKER = "FINAL ANSWER:"

    def __init__(self, provider: str = "groq") -> None:
        """Initialize the agent.

        Args:
            provider: The model provider to use (groq, google)

        Raises:
            ValueError: If the provider is unsupported or its API key
                environment variable is not set.
        """
        self.provider = provider
        self.tools = self._setup_tools()
        self.llm = self._setup_llm()
        self.llm_with_tools = self.llm.bind_tools(self.tools)

        # Load the system prompt before building the graph: the assistant node
        # reads self.system_message, so it should exist once the graph does.
        self.system_message = self._load_system_prompt()
        self.graph = self._build_graph()

    def _load_system_prompt(self) -> SystemMessage:
        """Load the system prompt from a file.

        Reads ``system_prompt.txt`` relative to the current working directory
        and falls back to a built-in prompt when that file does not exist.

        Returns:
            The prompt wrapped in a ``SystemMessage``.
        """
        try:
            with open("system_prompt.txt", "r", encoding="utf-8") as f:
                system_prompt = f.read()
        except FileNotFoundError:
            # Fallback system prompt if file not found
            system_prompt = (
                "You are a helpful assistant tasked with answering questions using a set of tools. "
                "Now, I will ask you a question. Report your thoughts, and finish your answer with the following template: "
                "FINAL ANSWER: [YOUR FINAL ANSWER]. "
                "YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. "
                "If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. "
                "If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. \n"
                "If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. "
                'Your answer should only start with "FINAL ANSWER: ", then follows with the answer.'
            )

        return SystemMessage(content=system_prompt)

    def _setup_tools(self) -> List[Any]:
        """Set up the tools for the agent.

        Returns:
            The list of tool objects (arithmetic helpers plus Wikipedia,
            Tavily and ArXiv search) to bind to the model and the tool node.
        """
        # NOTE: the docstrings of the functions below are picked up by @tool as
        # the tool descriptions, so their wording is part of runtime behavior.

        @tool
        def multiply(a: int, b: int) -> int:
            """Multiply two numbers.

            Args:
                a: first int
                b: second int
            """
            return a * b

        @tool
        def add(a: int, b: int) -> int:
            """Add two numbers.

            Args:
                a: first int
                b: second int
            """
            return a + b

        @tool
        def subtract(a: int, b: int) -> int:
            """Subtract two numbers.

            Args:
                a: first int
                b: second int
            """
            return a - b

        @tool
        def divide(a: int, b: int) -> float:
            """Divide two numbers.

            Args:
                a: first int
                b: second int
            """
            if b == 0:
                raise ValueError("Cannot divide by zero.")
            return a / b

        @tool
        def modulus(a: int, b: int) -> int:
            """Get the modulus of two numbers.

            Args:
                a: first int
                b: second int
            """
            return a % b

        @tool
        def wiki_search(query: str) -> Dict[str, str]:
            """Search Wikipedia for a query and return maximum 2 results.

            Args:
                query: The search query."""
            try:
                search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
                formatted_search_docs = "\n\n---\n\n".join(
                    f'\n{doc.page_content}\n' for doc in search_docs
                )
                return {"wiki_results": formatted_search_docs}
            except Exception as e:
                # Best effort: report the failure to the model instead of
                # aborting the whole graph run.
                return {"wiki_results": f"Error searching Wikipedia: {str(e)}"}

        @tool
        def web_search(query: str) -> Dict[str, str]:
            """Search Tavily for a query and return maximum 3 results.

            Args:
                query: The search query."""
            try:
                # The tool input is the first positional argument of invoke();
                # passing it as the keyword `query=` raises a TypeError.
                results = TavilySearchResults(max_results=3).invoke({"query": query})
                if isinstance(results, str):
                    # The tool may hand back a plain string (e.g. an error
                    # report) instead of a list of hits; pass it through.
                    return {"web_results": results}
                # Hits are dicts carrying the page text under "content",
                # not Document objects with a .page_content attribute.
                formatted_search_docs = "\n\n---\n\n".join(
                    f'\n{hit.get("content", "") if isinstance(hit, dict) else hit}\n'
                    for hit in results
                )
                return {"web_results": formatted_search_docs}
            except Exception as e:
                return {"web_results": f"Error searching web: {str(e)}"}

        @tool
        def arxiv_search(query: str) -> Dict[str, str]:
            """Search Arxiv for a query and return maximum 3 result.

            Args:
                query: The search query."""
            try:
                search_docs = ArxivLoader(query=query, load_max_docs=3).load()
                # Only the first 1000 characters of each paper are kept.
                formatted_search_docs = "\n\n---\n\n".join(
                    f'\n{doc.page_content[:1000]}\n' for doc in search_docs
                )
                return {"arxiv_results": formatted_search_docs}
            except Exception as e:
                return {"arxiv_results": f"Error searching ArXiv: {str(e)}"}

        return [
            multiply,
            add,
            subtract,
            divide,
            modulus,
            wiki_search,
            web_search,
            arxiv_search,
        ]

    def _setup_llm(self):
        """Set up the language model.

        Returns:
            A chat model for ``self.provider``.

        Raises:
            ValueError: If the provider is not "google" or "groq", or the
                matching API key environment variable is missing or empty.
        """
        if self.provider == "google":
            api_key = os.environ.get("GOOGLE_API_KEY")
            if not api_key:
                raise ValueError("GOOGLE_API_KEY environment variable not set")
            return ChatGoogleGenerativeAI(
                model="gemini-1.5-pro",
                temperature=0.1,
                google_api_key=api_key,
            )
        elif self.provider == "groq":
            api_key = os.environ.get("GROQ_API_KEY")
            if not api_key:
                raise ValueError("GROQ_API_KEY environment variable not set")
            return ChatGroq(
                model="llama3-70b-8192",
                temperature=0.1,
                groq_api_key=api_key,
            )
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    def _build_graph(self):
        """Build the agent graph.

        The graph starts at "assistant"; ``tools_condition`` decides after each
        assistant turn whether to route to "tools", and "tools" always returns
        to "assistant".

        Returns:
            The compiled graph.
        """

        def assistant(state: MessagesState):
            """The assistant node in the graph."""
            # The system message is prepended on every turn rather than being
            # stored in the graph state.
            messages = [self.system_message] + state["messages"]
            return {"messages": [self.llm_with_tools.invoke(messages)]}

        builder = StateGraph(MessagesState)
        builder.add_node("assistant", assistant)
        builder.add_node("tools", ToolNode(self.tools))
        builder.add_edge(START, "assistant")
        builder.add_conditional_edges(
            "assistant",
            tools_condition,
        )
        builder.add_edge("tools", "assistant")

        return builder.compile()

    @classmethod
    def _extract_final_answer(cls, content: Any) -> str:
        """Reduce a message's content to the text after the final-answer marker.

        Args:
            content: Message content; either a string or a list of content
                parts (strings, or dicts carrying a "text" entry).

        Returns:
            The stripped text after the last "FINAL ANSWER:" marker, or the
            full text unchanged when the marker is absent.
        """
        if isinstance(content, str):
            text = content
        elif isinstance(content, list):
            # Flatten multi-part content, keeping only the textual parts.
            pieces = []
            for part in content:
                if isinstance(part, str):
                    pieces.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    pieces.append(part["text"])
            text = "".join(pieces)
        else:
            text = str(content)

        if cls._FINAL_ANSWER_MARKER not in text:
            return text
        # Use the LAST marker: the model may mention the template while
        # reasoning before it states the actual answer.
        return text.rpartition(cls._FINAL_ANSWER_MARKER)[2].strip()

    def run(self, question: str) -> str:
        """Process a question and return the answer.

        Args:
            question: The question to answer

        Returns:
            The answer to the question, or a string starting with "Error: "
            when running the graph fails.
        """
        messages = [HumanMessage(content=question)]

        try:
            result = self.graph.invoke({"messages": messages})
            return self._extract_final_answer(result["messages"][-1].content)
        except Exception as e:
            # Top-level boundary: report the failure instead of raising so a
            # batch of questions can keep going.
            print(f"Error running agent: {e}")
            return f"Error: {str(e)}"