{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install plotly kaleido datasets nbformat -U -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import datasets\n",
    "import pandas as pd\n",
    "from dotenv import load_dotenv\n",
    "from huggingface_hub import login\n",
    "\n",
    "\n",
    "load_dotenv(override=True)\n",
    "login(os.getenv(\"HF_TOKEN\"))\n",
    "\n",
    "pd.set_option(\"max_colwidth\", None)\n",
    "\n",
    "OUTPUT_DIR = \"output\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_ds = datasets.load_dataset(\"gaia-benchmark/GAIA\", \"2023_all\")[\"validation\"]\n",
    "eval_ds = eval_ds.rename_columns({\"Question\": \"question\", \"Final answer\": \"true_answer\", \"Level\": \"task\"})\n",
    "eval_df = pd.DataFrame(eval_ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Load all results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "\n",
    "\n",
    "results = []\n",
    "for f in glob.glob(f\"{OUTPUT_DIR}/validation/*.jsonl\"):\n",
    "    df = pd.read_json(f, lines=True)\n",
    "    df[\"agent_name\"] = f.split(\"/\")[-1].split(\".\")[0]\n",
    "    results.append(df)\n",
    "\n",
    "result_df = pd.concat(results)\n",
    "result_df[\"prediction\"] = result_df[\"prediction\"].fillna(\"No prediction\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from collections import Counter\n",
    "\n",
    "from scripts.gaia_scorer import check_close_call, question_scorer\n",
    "\n",
    "\n",
    "result_df[\"is_correct\"] = result_df.apply(lambda x: question_scorer(x[\"prediction\"], x[\"true_answer\"]), axis=1)\n",
    "result_df[\"is_near_correct\"] = result_df.apply(\n",
    "    lambda x: check_close_call(x[\"prediction\"], x[\"true_answer\"], x[\"is_correct\"]),\n",
    "    axis=1,\n",
    ")\n",
    "\n",
    "result_df[\"count_steps\"] = result_df[\"intermediate_steps\"].apply(len)\n",
    "\n",
    "\n",
    "def find_attachment(question):\n",
    "    matches = eval_df.loc[eval_df[\"question\"].apply(lambda x: x in question), \"file_name\"]\n",
    "\n",
    "    if len(matches) == 0:\n",
    "        return \"Not found\"\n",
    "    file_path = matches.values[0]\n",
    "\n",
    "    if isinstance(file_path, str) and len(file_path) > 0:\n",
    "        return file_path.split(\".\")[-1]\n",
    "    else:\n",
    "        return \"None\"\n",
    "\n",
    "\n",
    "result_df[\"attachment_type\"] = result_df[\"question\"].apply(find_attachment)\n",
    "\n",
    "\n",
    "def extract_tool_calls(code):\n",
    "    regex = r\"\\b(\\w+)\\(\"\n",
    "    function_calls = [el for el in re.findall(regex, code) if el.islower()]\n",
    "\n",
    "    function_call_counter = Counter(function_calls)\n",
    "    return function_call_counter\n",
    "\n",
    "\n",
    "def sum_tool_calls(steps):\n",
    "    total_count = Counter()\n",
    "    for step in steps:\n",
    "        if \"llm_output\" in step:\n",
    "            total_count += extract_tool_calls(step[\"llm_output\"])\n",
    "\n",
    "    return total_count\n",
    "\n",
    "\n",
    "def get_durations(row):\n",
    "    # start_datetime = datetime.strptime(row['start_time'], \"%Y-%m-%d %H:%M:%S\")\n",
    "    # end_datetime = datetime.strptime(row['end_time'], \"%Y-%m-%d %H:%M:%S\")\n",
    "\n",
    "    duration_timedelta = row[\"end_time\"] - row[\"start_time\"]\n",
    "    return int(duration_timedelta.total_seconds())\n",
    "\n",
    "\n",
    "result_df[\"duration\"] = result_df.apply(get_durations, axis=1)\n",
    "# result_df[\"tool_calls\"] = result_df[\"intermediate_steps\"].apply(sum_tool_calls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_df[\"agent_name\"].value_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Inspect specific runs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sel_df = result_df\n",
    "# sel_df = sel_df.loc[\n",
    "#     (result_df[\"agent_name\"].isin(list_versions))\n",
    "# ]\n",
    "sel_df = sel_df.reset_index(drop=True)\n",
    "display(sel_df[\"agent_name\"].value_counts())\n",
    "sel_df = sel_df.drop_duplicates(subset=[\"agent_name\", \"question\"])\n",
    "display(sel_df.groupby(\"agent_name\")[[\"task\"]].value_counts())\n",
    "print(\"Total length:\", len(sel_df), \"- is complete:\", len(sel_df) == 165)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(\"Average score:\", sel_df.groupby(\"agent_name\")[[\"is_correct\"]].mean().round(3))\n",
    "display(\n",
    "    sel_df.groupby([\"agent_name\", \"task\"])[[\"is_correct\", \"is_near_correct\", \"count_steps\", \"question\", \"duration\"]]\n",
    "    .agg(\n",
    "        {\n",
    "            \"is_correct\": \"mean\",\n",
    "            \"is_near_correct\": \"mean\",\n",
    "            \"count_steps\": \"mean\",\n",
    "            \"question\": \"count\",\n",
    "            \"duration\": \"mean\",\n",
    "        }\n",
    "    )\n",
    "    .rename(columns={\"question\": \"count\"})\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.express as px\n",
    "\n",
    "\n",
    "cumulative_df = (\n",
    "    (\n",
    "        sel_df.groupby(\"agent_name\")[[\"is_correct\", \"is_near_correct\"]]\n",
    "        .expanding(min_periods=1, axis=0, method=\"single\")\n",
    "        .agg({\"is_correct\": \"mean\", \"is_near_correct\": \"count\"})\n",
    "        .reset_index()\n",
    "    )\n",
    "    .copy()\n",
    "    .rename(columns={\"is_near_correct\": \"index\"})\n",
    ")\n",
    "cumulative_df[\"index\"] = cumulative_df[\"index\"].astype(int) - 1\n",
    "\n",
    "\n",
    "def find_question(row):\n",
    "    try:\n",
    "        res = sel_df.loc[sel_df[\"agent_name\"] == row[\"agent_name\"], \"question\"].iloc[row[\"index\"]][:50]\n",
    "        return res\n",
    "    except Exception:\n",
    "        return \"\"\n",
    "\n",
    "\n",
    "cumulative_df[\"question\"] = cumulative_df.apply(find_question, axis=1)\n",
    "\n",
    "px.line(\n",
    "    cumulative_df,\n",
    "    color=\"agent_name\",\n",
    "    x=\"index\",\n",
    "    y=\"is_correct\",\n",
    "    hover_data=\"question\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Dive deeper into one run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sel_df = result_df.loc[result_df[\"agent_name\"] == \"o1\"]\n",
    "print(len(sel_df))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Count errors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "error_types = [\n",
    "    \"AgentParsingError\",\n",
    "    \"AgentExecutionError\",\n",
    "    \"AgentMaxIterationsError\",\n",
    "    \"AgentGenerationError\",\n",
    "]\n",
    "sel_df[error_types] = 0\n",
    "sel_df[\"Count steps\"] = np.nan\n",
    "\n",
    "\n",
    "def count_errors(row):\n",
    "    if isinstance(row[\"intermediate_steps\"], list):\n",
    "        row[\"Count steps\"] = len(row[\"intermediate_steps\"])\n",
    "        for step in row[\"intermediate_steps\"]:\n",
    "            if isinstance(step, dict) and \"error\" in step:\n",
    "                try:\n",
    "                    row[str(step[\"error\"][\"error_type\"])] += 1\n",
    "                except Exception:\n",
    "                    pass\n",
    "    return row\n",
    "\n",
    "\n",
    "sel_df = sel_df.apply(count_errors, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.express as px\n",
    "\n",
    "\n",
    "aggregate_errors = (\n",
    "    sel_df.groupby([\"is_correct\"])[error_types + [\"Count steps\"]].mean().reset_index().melt(id_vars=[\"is_correct\"])\n",
    ")\n",
    "\n",
    "fig = px.bar(\n",
    "    aggregate_errors,\n",
    "    y=\"value\",\n",
    "    x=\"variable\",\n",
    "    color=\"is_correct\",\n",
    "    labels={\n",
    "        \"agent_name\": \"Model\",\n",
    "        \"task\": \"Level\",\n",
    "        \"aggregate_score\": \"Performance\",\n",
    "        \"value\": \"Average count\",\n",
    "        \"eval_score_GPT4\": \"Score\",\n",
    "    },\n",
    ")\n",
    "fig.update_layout(\n",
    "    height=500,\n",
    "    width=800,\n",
    "    barmode=\"group\",\n",
    "    bargroupgap=0.0,\n",
    ")\n",
    "fig.update_traces(textposition=\"outside\")\n",
    "fig.write_image(\"aggregate_errors.png\", scale=3)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inspect result by file extension type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(\n",
    "    result_df.groupby([\"attachment_type\"])[[\"is_correct\", \"count_steps\", \"question\"]].agg(\n",
    "        {\"is_correct\": \"mean\", \"count_steps\": \"mean\", \"question\": \"count\"}\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4. Ensembling methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "counts = result_df[\"agent_name\"].value_counts()\n",
    "long_series = result_df.loc[result_df[\"agent_name\"].isin(counts[counts > 140].index)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def majority_vote(df):\n",
    "    df = df[(df[\"prediction\"] != \"Unable to determine\") & (~df[\"prediction\"].isna()) & (df[\"prediction\"] != \"None\")]\n",
    "\n",
    "    answer_modes = df.groupby(\"question\")[\"prediction\"].agg(lambda x: x.mode()[0]).reset_index()\n",
    "    first_occurrences = (\n",
    "        df.groupby([\"question\", \"prediction\"]).agg({\"task\": \"first\", \"is_correct\": \"first\"}).reset_index()\n",
    "    )\n",
    "    result = answer_modes.merge(first_occurrences, on=[\"question\", \"prediction\"], how=\"left\")\n",
    "\n",
    "    return result\n",
    "\n",
    "\n",
    "def oracle(df):\n",
    "    def get_first_correct_or_first_wrong(group):\n",
    "        correct_answers = group[group[\"is_correct\"]]\n",
    "        if len(correct_answers) > 0:\n",
    "            return correct_answers.iloc[0]\n",
    "        return group.iloc[0]\n",
    "\n",
    "    result = df.groupby(\"question\").apply(get_first_correct_or_first_wrong)\n",
    "\n",
    "    return result.reset_index(drop=True)\n",
    "\n",
    "\n",
    "display((long_series.groupby(\"agent_name\")[\"is_correct\"].mean() * 100).round(2))\n",
    "print(f\"Majority score: {majority_vote(long_series)['is_correct'].mean() * 100:.2f}\")\n",
    "print(f\"Oracle score: {oracle(long_series)['is_correct'].mean() * 100:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Submit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent_run = \"code_o1_04_february_submission5.jsonl\"\n",
    "df = pd.read_json(f\"output/validation/{agent_run}\", lines=True)\n",
    "df = df[[\"task_id\", \"prediction\", \"intermediate_steps\"]]\n",
    "df = df.rename(columns={\"prediction\": \"model_answer\", \"intermediate_steps\": \"reasoning_trace\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df.to_json(\"submission.jsonl\", orient=\"records\", lines=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "agents",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}