"""Test the GAIA agent on a batch of questions.""" import os import json import argparse from typing import List, Dict, Any from dotenv import load_dotenv from gaia_agent import GAIAAgent def load_questions(file_path: str, max_questions: int = None) -> List[Dict[str, Any]]: """Load questions from a JSONL file. Args: file_path: Path to the JSONL file max_questions: Maximum number of questions to load Returns: List of questions """ questions = [] try: with open(file_path, "r", encoding="utf-8") as f: for line in f: if line.strip(): question_data = json.loads(line) questions.append({ "task_id": question_data.get("task_id", ""), "question": question_data.get("Question", ""), "expected_answer": question_data.get("Final answer", ""), "level": question_data.get("Level", "") }) if max_questions and len(questions) >= max_questions: break except Exception as e: print(f"Error loading questions: {e}") return [] return questions def test_batch(file_path: str, provider: str = "groq", max_questions: int = None, output_file: str = "batch_results.json"): """Test the GAIA agent on a batch of questions. Args: file_path: Path to the JSONL file containing questions provider: The model provider to use max_questions: Maximum number of questions to test output_file: Path to the output file for results """ # Load environment variables load_dotenv() # Check for required API keys if provider == "groq" and not os.getenv("GROQ_API_KEY"): print("Warning: GROQ_API_KEY not found, defaulting to Google provider") provider = "google" if provider == "google" and not os.getenv("GOOGLE_API_KEY"): print("Warning: GOOGLE_API_KEY not found, please set it in the .env file") return # Load questions questions = load_questions(file_path, max_questions) if not questions: print("No questions loaded") return print(f"Loaded {len(questions)} questions") # Initialize the agent try: agent = GAIAAgent(provider=provider) print(f"Initialized agent with provider: {provider}") except Exception as e: print(f"Error initializing agent: {e}") return # Run the agent on each question results = [] for i, question_data in enumerate(questions): question = question_data["question"] expected_answer = question_data["expected_answer"] task_id = question_data["task_id"] level = question_data["level"] print(f"[{i+1}/{len(questions)}] Testing question: {task_id}") print(f"Question: {question}") print(f"Expected answer: {expected_answer}") try: answer = agent.run(question) print(f"Agent answer: {answer}") # Check if the answer is correct is_correct = answer.strip().lower() == expected_answer.strip().lower() print(f"Correct: {is_correct}") results.append({ "task_id": task_id, "question": question, "expected_answer": expected_answer, "agent_answer": answer, "is_correct": is_correct, "level": level }) print("-" * 80) except Exception as e: print(f"Error running agent: {e}") results.append({ "task_id": task_id, "question": question, "expected_answer": expected_answer, "agent_answer": f"ERROR: {str(e)}", "is_correct": False, "level": level }) print("-" * 80) # Calculate accuracy correct_count = sum(1 for result in results if result["is_correct"]) accuracy = correct_count / len(results) if results else 0 print(f"Accuracy: {accuracy:.2%} ({correct_count}/{len(results)})") # Save results with open(output_file, "w", encoding="utf-8") as f: json.dump({ "results": results, "accuracy": accuracy, "correct_count": correct_count, "total_count": len(results) }, f, indent=2) print(f"Results saved to {output_file}") def main(): """Main function.""" parser = argparse.ArgumentParser(description="Test the GAIA agent on a batch of questions") parser.add_argument("file_path", type=str, help="Path to the JSONL file containing questions") parser.add_argument("--provider", type=str, default="groq", choices=["groq", "google", "anthropic", "openai"], help="The model provider to use") parser.add_argument("--max-questions", type=int, default=None, help="Maximum number of questions to test") parser.add_argument("--output-file", type=str, default="batch_results.json", help="Path to the output file for results") args = parser.parse_args() test_batch(args.file_path, args.provider, args.max_questions, args.output_file) if __name__ == "__main__": main()