schoemantian committed
Commit a64f076 · verified · 1 parent: 31b01d1

Fix Errors in responses

Files changed (3)
  1. app.py +74 -131
  2. gaia_agent.py +156 -172
  3. system_prompt.txt +16 -21
app.py CHANGED
@@ -6,163 +6,114 @@ import pandas as pd
 from dotenv import load_dotenv
 from gaia_agent import GAIAAgent
 
-# Load environment variables
 load_dotenv()
 
-# Constants
 DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
-class GAIAAssessmentAgent:
-    """Agent wrapper for the GAIA assessment."""
 
-    def __init__(self, provider="groq"):
-        """Initialize the agent with the specified provider.
-
-        Args:
-            provider: The model provider to use ("groq", "google", "anthropic", "openai")
-        """
-        print(f"Initializing GAIAAssessmentAgent with provider: {provider}")
         self.agent = GAIAAgent(provider=provider)
-        print("Agent initialized successfully")
-
     def __call__(self, question: str) -> str:
-        """Process a question and return the answer.
-
-        Args:
-            question: The question to answer
-
-        Returns:
-            The answer to the question
-        """
-        print(f"Processing question (first 50 chars): {question[:50]}...")
-        answer = self.agent.run(question)
-        print(f"Answer: {answer}")
-        return answer
 
 def run_and_submit_all(profile: gr.OAuthProfile | None):
-    """Fetches all questions, runs the agent on them, submits all answers,
     and displays the results.
-
-    Args:
-        profile: The user's Hugging Face profile
-
-    Returns:
-        A tuple containing the status message and results table
     """
-    # Get Space ID for code link
-    space_id = os.getenv("SPACE_ID")
-
-    # Check if user is logged in
     if profile:
-        username = f"{profile.username}"
         print(f"User logged in: {username}")
     else:
         print("User not logged in.")
-        return "Please login to Hugging Face with the button to submit your answers.", None
-
-    # API endpoints
     api_url = DEFAULT_API_URL
     questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"
-
-    # Initialize agent
     try:
-        # Choose a provider based on available API keys
-        if os.getenv("GROQ_API_KEY"):
-            provider = "groq"
-        elif os.getenv("GOOGLE_API_KEY"):
-            provider = "google"
-        elif os.getenv("ANTHROPIC_API_KEY"):
-            provider = "anthropic"
-        elif os.getenv("OPENAI_API_KEY"):
-            provider = "openai"
-        else:
-            provider = "groq"  # Default to Groq
-
-        agent = GAIAAssessmentAgent(provider=provider)
     except Exception as e:
-        print(f"Error initializing agent: {e}")
         return f"Error initializing agent: {e}", None
-
-    # Generate code link for submission
     agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
-    print(f"Code link: {agent_code}")
-
-    # Fetch questions
     print(f"Fetching questions from: {questions_url}")
     try:
         response = requests.get(questions_url, timeout=30)
         response.raise_for_status()
         questions_data = response.json()
-
         if not questions_data:
-            print("Fetched questions list is empty.")
-            return "Fetched questions list is empty or invalid format.", None
-
         print(f"Fetched {len(questions_data)} questions.")
     except requests.exceptions.RequestException as e:
         print(f"Error fetching questions: {e}")
         return f"Error fetching questions: {e}", None
     except requests.exceptions.JSONDecodeError as e:
-        print(f"Error decoding JSON response from questions endpoint: {e}")
-        print(f"Response text: {response.text[:500]}")
-        return f"Error decoding server response for questions: {e}", None
     except Exception as e:
         print(f"An unexpected error occurred fetching questions: {e}")
         return f"An unexpected error occurred fetching questions: {e}", None
-
-    # Run agent on all questions
     results_log = []
     answers_payload = []
     print(f"Running agent on {len(questions_data)} questions...")
-
-    for i, item in enumerate(questions_data):
         task_id = item.get("task_id")
         question_text = item.get("question")
-
         if not task_id or question_text is None:
             print(f"Skipping item with missing task_id or question: {item}")
             continue
-
-        print(f"Processing question {i+1}/{len(questions_data)}: {task_id}")
-
         try:
             submitted_answer = agent(question_text)
             answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
-            results_log.append({
-                "Task ID": task_id,
-                "Question": question_text,
-                "Submitted Answer": submitted_answer
-            })
-            print(f"Question {i+1} processed successfully")
         except Exception as e:
-            print(f"Error running agent on task {task_id}: {e}")
-            results_log.append({
-                "Task ID": task_id,
-                "Question": question_text,
-                "Submitted Answer": f"AGENT ERROR: {e}"
-            })
-
     if not answers_payload:
         print("Agent did not produce any answers to submit.")
         return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
-
-    # Prepare submission
-    submission_data = {
-        "username": username.strip(),
-        "agent_code": agent_code,
-        "answers": answers_payload
-    }
     status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
     print(status_update)
-
-    # Submit answers
     print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
     try:
         response = requests.post(submit_url, json=submission_data, timeout=60)
         response.raise_for_status()
         result_data = response.json()
-
         final_status = (
             f"Submission Successful!\n"
             f"User: {result_data.get('username')}\n"
@@ -170,7 +121,6 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
             f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
             f"Message: {result_data.get('message', 'No message received.')}"
         )
-
         print("Submission successful.")
         results_df = pd.DataFrame(results_log)
         return final_status, results_df
@@ -181,7 +131,6 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
             error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
         except requests.exceptions.JSONDecodeError:
             error_detail += f" Response: {e.response.text[:500]}"
-
         status_message = f"Submission Failed: {error_detail}"
         print(status_message)
         results_df = pd.DataFrame(results_log)
@@ -202,33 +151,29 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
     results_df = pd.DataFrame(results_log)
     return status_message, results_df
 
-# Build Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# GAIA Assessment Runner for Hugging Face Agents Course")
     gr.Markdown(
         """
        **Instructions:**
 
-        1. This space implements a comprehensive agent for the GAIA benchmark using several key technologies:
-           - LangGraph for agent orchestration
-           - Tool use for information retrieval
-           - Web search, Wikipedia, and ArXiv tools for research
-           - Mathematical tools for computation
 
-        2. Log in to your Hugging Face account using the button below. This is required for submission.
 
-        3. Click 'Run Evaluation & Submit Answers' to fetch questions, run the agent, and submit answers.
 
-        **Note:** The process may take some time as the agent runs through all questions.
-
-        ---
-
-        Good luck with your assessment! 🚀
        """
     )
 
     gr.LoginButton()
-    run_button = gr.Button("Run Evaluation & Submit Answers", variant="primary")
     status_output = gr.Textbox(label="Submission Status", lines=5, interactive=False)
     results_table = gr.DataFrame(label="Questions and Answers", wrap=True)
 
@@ -238,26 +183,24 @@ with gr.Blocks() as demo:
     )
 
 if __name__ == "__main__":
-    print("\n" + "-"*30 + " Starting GAIA Assessment Runner " + "-"*30)
-
-    # Check for environment variables
-    space_host = os.getenv("SPACE_HOST")
-    space_id = os.getenv("SPACE_ID")
 
-    if space_host:
-        print(f"✅ SPACE_HOST found: {space_host}")
-        print(f" Runtime URL: https://{space_host}.hf.space")
     else:
         print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
 
-    if space_id:
-        print(f"✅ SPACE_ID found: {space_id}")
-        print(f" Repo URL: https://huggingface.co/spaces/{space_id}")
-        print(f" Code URL: https://huggingface.co/spaces/{space_id}/tree/main")
     else:
-        print("ℹ️ SPACE_ID environment variable not found. Repo URL cannot be determined.")
 
-    print("-"*(65 + len(" Starting GAIA Assessment Runner ")) + "\n")
     print("Launching Gradio Interface for GAIA Assessment...")
-
     demo.launch(debug=True, share=False)
 
 from dotenv import load_dotenv
 from gaia_agent import GAIAAgent
 
 load_dotenv()
 
 DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
+class BasicAgent:
+    """A simple wrapper for the GAIA Agent."""
 
+    def __init__(self):
+        print("BasicAgent initialized.")
+        if os.getenv("GROQ_API_KEY"):
+            provider = "groq"
+        elif os.getenv("GOOGLE_API_KEY"):
+            provider = "google"
+        else:
+            provider = "groq"
+
         self.agent = GAIAAgent(provider=provider)
+
     def __call__(self, question: str) -> str:
+        print(f"Agent received question (first 50 chars): {question[:50]}...")
+        try:
+            answer = self.agent.run(question)
+            print(f"Agent returning answer: {answer}")
+            return answer
+        except Exception as e:
+            print(f"Error processing question: {e}")
+            return f"Error: {str(e)}"
+
 
 def run_and_submit_all(profile: gr.OAuthProfile | None):
+    """
+    Fetches all questions, runs the BasicAgent on them, submits all answers,
     and displays the results.
     """
+    space_id = os.getenv("SPACE_ID")
+
     if profile:
+        username= f"{profile.username}"
         print(f"User logged in: {username}")
     else:
         print("User not logged in.")
+        return "Please Login to Hugging Face with the button.", None
+
     api_url = DEFAULT_API_URL
     questions_url = f"{api_url}/questions"
     submit_url = f"{api_url}/submit"
+
     try:
+        agent = BasicAgent()
     except Exception as e:
+        print(f"Error instantiating agent: {e}")
         return f"Error initializing agent: {e}", None
+
     agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
+    print(f"Agent code: {agent_code}")
+
     print(f"Fetching questions from: {questions_url}")
     try:
         response = requests.get(questions_url, timeout=30)
         response.raise_for_status()
         questions_data = response.json()
         if not questions_data:
+            print("Fetched questions list is empty.")
+            return "Fetched questions list is empty or invalid format.", None
         print(f"Fetched {len(questions_data)} questions.")
     except requests.exceptions.RequestException as e:
         print(f"Error fetching questions: {e}")
         return f"Error fetching questions: {e}", None
     except requests.exceptions.JSONDecodeError as e:
+        print(f"Error decoding JSON response from questions endpoint: {e}")
+        print(f"Response text: {response.text[:500]}")
+        return f"Error decoding server response for questions: {e}", None
     except Exception as e:
         print(f"An unexpected error occurred fetching questions: {e}")
         return f"An unexpected error occurred fetching questions: {e}", None
+
     results_log = []
     answers_payload = []
     print(f"Running agent on {len(questions_data)} questions...")
+    for item in questions_data:
         task_id = item.get("task_id")
         question_text = item.get("question")
         if not task_id or question_text is None:
             print(f"Skipping item with missing task_id or question: {item}")
             continue
         try:
+            print(f"Processing question: {task_id}")
             submitted_answer = agent(question_text)
             answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
+            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
+            print(f"Question processed successfully")
         except Exception as e:
+            print(f"Error running agent on task {task_id}: {e}")
+            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
+
     if not answers_payload:
         print("Agent did not produce any answers to submit.")
         return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
+
+    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
     status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
     print(status_update)
+
     print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
     try:
         response = requests.post(submit_url, json=submission_data, timeout=60)
         response.raise_for_status()
         result_data = response.json()
         final_status = (
             f"Submission Successful!\n"
             f"User: {result_data.get('username')}\n"
             f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
             f"Message: {result_data.get('message', 'No message received.')}"
         )
         print("Submission successful.")
         results_df = pd.DataFrame(results_log)
         return final_status, results_df
             error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
         except requests.exceptions.JSONDecodeError:
             error_detail += f" Response: {e.response.text[:500]}"
         status_message = f"Submission Failed: {error_detail}"
         print(status_message)
         results_df = pd.DataFrame(results_log)
     results_df = pd.DataFrame(results_log)
     return status_message, results_df
 
+
 with gr.Blocks() as demo:
+    gr.Markdown("# GAIA Assessment Runner")
     gr.Markdown(
         """
        **Instructions:**
 
+        1. This implementation uses a robust LangGraph agent with multiple tools:
+           - Web search for real-time information
+           - Wikipedia for factual knowledge
+           - ArXiv for academic research
+           - Mathematical tools for calculations
 
+        2. Log in to your Hugging Face account using the button below.
 
+        3. Click 'Run Evaluation & Submit Answers' to run the agent and submit results.
 
+        **Note:** Processing may take some time as the agent works through all questions.
        """
     )
 
     gr.LoginButton()
+    run_button = gr.Button("Run Evaluation & Submit Answers")
     status_output = gr.Textbox(label="Submission Status", lines=5, interactive=False)
     results_table = gr.DataFrame(label="Questions and Answers", wrap=True)
 
     )
 
 if __name__ == "__main__":
+    print("\n" + "-"*30 + " App Starting " + "-"*30)
+    space_host_startup = os.getenv("SPACE_HOST")
+    space_id_startup = os.getenv("SPACE_ID")
 
+    if space_host_startup:
+        print(f"✅ SPACE_HOST found: {space_host_startup}")
+        print(f" Runtime URL should be: https://{space_host_startup}.hf.space")
     else:
         print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
 
+    if space_id_startup:
+        print(f"✅ SPACE_ID found: {space_id_startup}")
+        print(f" Repo URL: https://huggingface.co/spaces/{space_id_startup}")
+        print(f" Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
     else:
+        print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")
+
+    print("-"*(60 + len(" App Starting ")) + "\n")
 
     print("Launching Gradio Interface for GAIA Assessment...")
     demo.launch(debug=True, share=False)
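
As a sanity check while reviewing, the new BasicAgent wrapper can be exercised outside the Gradio UI. This is only an illustrative sketch, not part of the commit: the file name smoke_test.py is made up, and it assumes the Space's dependencies are installed locally and that GROQ_API_KEY or GOOGLE_API_KEY is set so GAIAAgent can initialize.

# smoke_test.py -- illustrative sketch, not part of this commit
from app import BasicAgent  # importing app.py builds the Gradio Blocks but does not launch them

if __name__ == "__main__":
    agent = BasicAgent()
    # BasicAgent.__call__ catches exceptions and returns "Error: ..." strings,
    # so this call should not raise even if the LLM backend is unreachable.
    print(agent("What is the capital of France?"))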
gaia_agent.py CHANGED
@@ -1,254 +1,238 @@
 
 import os
-from typing import List, Dict, Any, Optional
 from dotenv import load_dotenv
-from langgraph.graph import START, END, StateGraph, MessagesState
-
-from langchain_core.messages import SystemMessage, HumanMessage
-from langchain_groq import ChatGroq
-from langchain_google_genai import ChatGoogleGenerativeAI
-from langchain_core.tools import tool
-
 from langgraph.graph import START, StateGraph, MessagesState
 from langgraph.prebuilt import tools_condition
 from langgraph.prebuilt import ToolNode
-
 from langchain_community.tools.tavily_search import TavilySearchResults
 from langchain_community.document_loaders import WikipediaLoader
 from langchain_community.document_loaders import ArxivLoader
 
-# Load environment variables
 load_dotenv()
 
 class GAIAAgent:
-    """Agent for answering GAIA assessment questions."""
 
     def __init__(self, provider="groq"):
-        """Initialize the agent with the specified provider.
 
         Args:
-            provider: Model provider - "groq", "google", "anthropic", or "openai"
         """
-        # Set up the system prompt
-        with open("system_prompt.txt", "r", encoding="utf-8") as f:
-            system_prompt = f.read()
-
-        self.system_message = SystemMessage(content=system_prompt)
-
-        # Initialize tools
         self.tools = self._setup_tools()
-
-        # Initialize LLM based on provider
-        self.llm = self._setup_llm(provider)
-
-        # Bind tools to LLM
         self.llm_with_tools = self.llm.bind_tools(self.tools)
-
-        # Build the agent graph
         self.graph = self._build_graph()
 
     def _setup_tools(self):
         """Set up the tools for the agent."""
 
         @tool
-        def web_search(query: str) -> str:
-            """Search the web for real-time information.
 
             Args:
-                query: The search query
-
-            Returns:
-                Search results as text
             """
-            search_results = TavilySearchResults(max_results=3).invoke(query)
-            formatted_results = "\n\n".join([
-                f"SOURCE: {result.metadata.get('source', 'Unknown')}\n{result.page_content}"
-                for result in search_results
-            ])
-            return formatted_results
 
         @tool
-        def wiki_search(query: str) -> str:
-            """Search Wikipedia for information.
 
             Args:
-                query: The search query
-
-            Returns:
-                Wikipedia article content
             """
-            try:
-                wiki_docs = WikipediaLoader(query=query, load_max_docs=2).load()
-                if not wiki_docs:
-                    return "No Wikipedia results found."
-
-                formatted_results = "\n\n".join([
-                    f"TITLE: {doc.metadata.get('title', 'Unknown')}\n{doc.page_content[:1000]}..."
-                    for doc in wiki_docs
-                ])
-                return formatted_results
-            except Exception as e:
-                return f"Error searching Wikipedia: {str(e)}"
-
         @tool
-        def arxiv_search(query: str) -> str:
-            """Search arXiv for scientific papers.
 
             Args:
-                query: The search query
-
-            Returns:
-                ArXiv paper information
             """
             try:
-                arxiv_docs = ArxivLoader(query=query, load_max_docs=2).load()
-                if not arxiv_docs:
-                    return "No arXiv results found."
-
-                formatted_results = "\n\n".join([
-                    f"TITLE: {doc.metadata.get('title', 'Unknown')}\n"
-                    f"AUTHORS: {doc.metadata.get('authors', 'Unknown')}\n"
-                    f"PUBLISHED: {doc.metadata.get('published', 'Unknown')}\n\n"
-                    f"ABSTRACT: {doc.page_content[:500]}..."
-                    for doc in arxiv_docs
-                ])
-                return formatted_results
             except Exception as e:
-                return f"Error searching arXiv: {str(e)}"
-
         @tool
-        def calculate(expression: str) -> str:
-            """Evaluate a mathematical expression.
 
             Args:
-                expression: The mathematical expression to evaluate
-
-            Returns:
-                The result of the calculation
-            """
             try:
-                # Safely evaluate the expression
-                result = eval(expression, {"__builtins__": {}}, {})
-                return f"Result: {result}"
             except Exception as e:
-                return f"Error calculating: {str(e)}"
-
-        return [web_search, wiki_search, arxiv_search, calculate]
-
-    def _setup_llm(self, provider):
-        """Set up the language model based on the provider.
-
-        Args:
-            provider: The model provider to use
 
-        Returns:
-            The initialized language model
-        """
-        if provider == "groq":
-            api_key = os.getenv("GROQ_API_KEY")
-            if not api_key:
-                raise ValueError("GROQ_API_KEY environment variable not set")
 
-            return ChatGroq(
-                model="llama3-70b-8192",  # Using Llama 3 70B model for best results
-                temperature=0.1,  # Low temperature for more precise answers
-                groq_api_key=api_key
-            )
-        elif provider == "google":
-            api_key = os.getenv("GOOGLE_API_KEY")
             if not api_key:
                 raise ValueError("GOOGLE_API_KEY environment variable not set")
 
             return ChatGoogleGenerativeAI(
-                model="gemini-1.5-pro",
                 temperature=0.1,
                 google_api_key=api_key
             )
-        elif provider == "anthropic":
-            # Import only if needed to avoid dependency issues
-            from langchain_anthropic import ChatAnthropic
-
-            api_key = os.getenv("ANTHROPIC_API_KEY")
-            if not api_key:
-                raise ValueError("ANTHROPIC_API_KEY environment variable not set")
-
-            return ChatAnthropic(
-                model="claude-3-opus-20240229",
-                temperature=0.1,
-                anthropic_api_key=api_key
-            )
-        elif provider == "openai":
-            # Import only if needed to avoid dependency issues
-            from langchain_openai import ChatOpenAI
-
-            api_key = os.getenv("OPENAI_API_KEY")
             if not api_key:
-                raise ValueError("OPENAI_API_KEY environment variable not set")
 
-            return ChatOpenAI(
-                model="gpt-4o",
                 temperature=0.1,
-                openai_api_key=api_key
             )
         else:
-            raise ValueError(f"Unsupported provider: {provider}")
 
     def _build_graph(self):
-        """Build the agent graph.
 
-        Returns:
-            The compiled state graph
-        """
-        # Define the agent node
-        def agent(state: MessagesState):
-            """Generate a response or tool calls based on the messages state."""
-            # Include system message with each invocation for consistent behavior
             messages = [self.system_message] + state["messages"]
-            response = self.llm_with_tools.invoke(messages)
-            return {"messages": state["messages"] + [response]}
 
-        # Create the graph
         builder = StateGraph(MessagesState)
-
-        # Add nodes
-        builder.add_node("agent", agent)
         builder.add_node("tools", ToolNode(self.tools))
-
-        # Add edges
-        builder.add_edge(START, "agent")
         builder.add_conditional_edges(
-            "agent",
             tools_condition,
-            {
-                "tools": "tools",
-                None: END  # END is implicitly defined in langgraph
-            }
         )
-        builder.add_edge("tools", "agent")
 
-        # Compile the graph
         return builder.compile()
 
     def run(self, question: str) -> str:
         """Process a question and return the answer.
 
         Args:
-            question: The question to process
 
         Returns:
             The answer to the question
         """
-        # Initialize messages with the user question
         messages = [HumanMessage(content=question)]
 
-        # Execute the graph
-        result = self.graph.invoke({"messages": messages})
-
-        # Extract the final answer
-        final_messages = result["messages"]
-        final_answer = final_messages[-1].content
-
-        # Extract only the part after "FINAL ANSWER:"
-        if "FINAL ANSWER:" in final_answer:
-            final_answer = final_answer.split("FINAL ANSWER:")[1].strip()
 
-        return final_answer
+"""LangGraph Agent for GAIA Assessment"""
 import os
+from typing import List, Dict, Any
 from dotenv import load_dotenv
 from langgraph.graph import START, StateGraph, MessagesState
 from langgraph.prebuilt import tools_condition
 from langgraph.prebuilt import ToolNode
+from langchain_core.messages import SystemMessage, HumanMessage
+from langchain_core.tools import tool
+from langchain_groq import ChatGroq
+from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_community.tools.tavily_search import TavilySearchResults
 from langchain_community.document_loaders import WikipediaLoader
 from langchain_community.document_loaders import ArxivLoader
 
 load_dotenv()
 
 class GAIAAgent:
+    """Agent for the GAIA assessment."""
 
     def __init__(self, provider="groq"):
+        """Initialize the agent.
 
         Args:
+            provider: The model provider to use (groq, google)
         """
+        self.provider = provider
         self.tools = self._setup_tools()
+        self.llm = self._setup_llm()
         self.llm_with_tools = self.llm.bind_tools(self.tools)
         self.graph = self._build_graph()
 
+        # Load system prompt
+        self.system_message = self._load_system_prompt()
+
+    def _load_system_prompt(self):
+        """Load the system prompt from a file."""
+        try:
+            with open("system_prompt.txt", "r", encoding="utf-8") as f:
+                system_prompt = f.read()
+        except FileNotFoundError:
+            # Fallback system prompt if file not found
+            system_prompt = """You are a helpful assistant tasked with answering questions using a set of tools.
+Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
+FINAL ANSWER: [YOUR FINAL ANSWER].
+YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
+If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
+If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
+If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
+Your answer should only start with "FINAL ANSWER: ", then follows with the answer."""
+
+        return SystemMessage(content=system_prompt)
+
     def _setup_tools(self):
         """Set up the tools for the agent."""
 
         @tool
+        def multiply(a: int, b: int) -> int:
+            """Multiply two numbers.
+
+            Args:
+                a: first int
+                b: second int
+            """
+            return a * b
+
+        @tool
+        def add(a: int, b: int) -> int:
+            """Add two numbers.
 
             Args:
+                a: first int
+                b: second int
             """
+            return a + b
+
+        @tool
+        def subtract(a: int, b: int) -> int:
+            """Subtract two numbers.
 
+            Args:
+                a: first int
+                b: second int
+            """
+            return a - b
+
         @tool
+        def divide(a: int, b: int) -> float:
+            """Divide two numbers.
 
             Args:
+                a: first int
+                b: second int
             """
+            if b == 0:
+                raise ValueError("Cannot divide by zero.")
+            return a / b
+
         @tool
+        def modulus(a: int, b: int) -> int:
+            """Get the modulus of two numbers.
 
             Args:
+                a: first int
+                b: second int
             """
+            return a % b
+
+        @tool
+        def wiki_search(query: str) -> str:
+            """Search Wikipedia for a query and return maximum 2 results.
+
+            Args:
+                query: The search query."""
             try:
+                search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
+                formatted_search_docs = "\n\n---\n\n".join(
+                    [
+                        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+                        for doc in search_docs
+                    ])
+                return {"wiki_results": formatted_search_docs}
             except Exception as e:
+                return {"wiki_results": f"Error searching Wikipedia: {str(e)}"}
+
         @tool
+        def web_search(query: str) -> str:
+            """Search Tavily for a query and return maximum 3 results.
 
             Args:
+                query: The search query."""
             try:
+                search_docs = TavilySearchResults(max_results=3).invoke(query=query)
+                formatted_search_docs = "\n\n---\n\n".join(
+                    [
+                        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+                        for doc in search_docs
+                    ])
+                return {"web_results": formatted_search_docs}
             except Exception as e:
+                return {"web_results": f"Error searching web: {str(e)}"}
+
+        @tool
+        def arxiv_search(query: str) -> str:
+            """Search Arxiv for a query and return maximum 3 result.
 
+            Args:
+                query: The search query."""
+            try:
+                search_docs = ArxivLoader(query=query, load_max_docs=3).load()
+                formatted_search_docs = "\n\n---\n\n".join(
+                    [
+                        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
+                        for doc in search_docs
+                    ])
+                return {"arxiv_results": formatted_search_docs}
+            except Exception as e:
+                return {"arxiv_results": f"Error searching ArXiv: {str(e)}"}
 
+        return [
+            multiply,
+            add,
+            subtract,
+            divide,
+            modulus,
+            wiki_search,
+            web_search,
+            arxiv_search,
+        ]
+
+    def _setup_llm(self):
+        """Set up the language model."""
+        if self.provider == "google":
+            api_key = os.environ.get("GOOGLE_API_KEY")
             if not api_key:
                 raise ValueError("GOOGLE_API_KEY environment variable not set")
 
             return ChatGoogleGenerativeAI(
+                model="gemini-1.5-pro",
                 temperature=0.1,
                 google_api_key=api_key
             )
+        elif self.provider == "groq":
+            api_key = os.environ.get("GROQ_API_KEY")
             if not api_key:
+                raise ValueError("GROQ_API_KEY environment variable not set")
 
+            return ChatGroq(
+                model="llama3-70b-8192",
                 temperature=0.1,
+                groq_api_key=api_key
             )
         else:
+            raise ValueError(f"Unsupported provider: {self.provider}")
 
     def _build_graph(self):
+        """Build the agent graph."""
 
+        def assistant(state: MessagesState):
+            """The assistant node in the graph."""
             messages = [self.system_message] + state["messages"]
+            return {"messages": [self.llm_with_tools.invoke(messages)]}
 
         builder = StateGraph(MessagesState)
+        builder.add_node("assistant", assistant)
         builder.add_node("tools", ToolNode(self.tools))
+        builder.add_edge(START, "assistant")
         builder.add_conditional_edges(
+            "assistant",
             tools_condition,
         )
+        builder.add_edge("tools", "assistant")
 
         return builder.compile()
 
     def run(self, question: str) -> str:
         """Process a question and return the answer.
 
         Args:
+            question: The question to answer
 
         Returns:
             The answer to the question
         """
         messages = [HumanMessage(content=question)]
 
+        try:
+            result = self.graph.invoke({"messages": messages})
 
+            final_answer = result["messages"][-1].content
+
+            if "FINAL ANSWER:" in final_answer:
+                final_answer = final_answer.split("FINAL ANSWER:")[1].strip()
+
+            return final_answer
+        except Exception as e:
+            print(f"Error running agent: {e}")
+            return f"Error: {str(e)}"
system_prompt.txt CHANGED
@@ -1,32 +1,27 @@
-You are a precise AI assistant tasked with answering questions for the GAIA benchmark. Your goal is to provide accurate and concise answers to complex questions.
 
-Follow these guidelines:
-1. Use the provided tools to gather information when needed.
-2. Think step-by-step to break down complex questions.
-3. For web searches, be specific and try multiple queries if needed.
-4. When answering math questions, show your calculations clearly.
-5. Always verify your answer before finalizing it.
 
-Format your final answer with:
-FINAL ANSWER: [YOUR FINAL ANSWER]
-
-YOUR FINAL ANSWER should be:
-- A number WITHOUT commas or units (unless specified otherwise)
-- As few words as possible for text answers
-- A comma-separated list for multiple items
-- No articles or abbreviations in string answers
-- Digits in plain text unless specified otherwise
-
-Example 1:
 Question: What is the capital of France?
 FINAL ANSWER: Paris
 
-Example 2:
 Question: What are the first 3 prime numbers?
 FINAL ANSWER: 2, 3, 5
 
-Example 3:
 Question: Calculate 15% of 240.
 FINAL ANSWER: 36
 
-Now, I will ask you a question. Use the tools available to research if needed, then provide your final answer in the specified format.
+You are a helpful assistant tasked with answering questions using a set of tools.
+Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
+FINAL ANSWER: [YOUR FINAL ANSWER].
+YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
+Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
 
+Here are some example questions and answers:
 
 Question: What is the capital of France?
+Thought: The capital of France is Paris.
 FINAL ANSWER: Paris
 
 Question: What are the first 3 prime numbers?
+Thought: The first three prime numbers are 2, 3, and 5.
 FINAL ANSWER: 2, 3, 5
 
 Question: Calculate 15% of 240.
+Thought: To calculate 15% of 240, I multiply 240 by 0.15. This gives me 240 * 0.15 = 36.
 FINAL ANSWER: 36
 
+For each question:
+1. Think through the problem step-by-step
+2. Use tools when needed to gather information
+3. Ensure you understand exactly what is being asked
+4. Format your final answer according to the template
+
+If you need to search for information, be specific in your queries. If you need to perform calculations, show your work. Always double-check your answer before submitting it.
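
The "FINAL ANSWER:" template is the contract between this prompt and GAIAAgent.run(): the agent splits the model output on that marker and submits only what follows, so a response that omits the marker is submitted verbatim. A minimal illustration of that post-processing, mirroring the split in gaia_agent.py with a made-up model output:

raw = "Thought: The capital of France is Paris.\nFINAL ANSWER: Paris"
final = raw.split("FINAL ANSWER:")[1].strip() if "FINAL ANSWER:" in raw else raw
print(final)  # -> Paris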