vonliechti commited on
Commit
622f963
·
verified ·
1 Parent(s): 69e37b7

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. agent.py +8 -5
  2. app.py +1 -1
  3. benchmarking.ipynb +2 -2
agent.py CHANGED
@@ -3,18 +3,21 @@ from prompts import SQUAD_REACT_CODE_SYSTEM_PROMPT
3
  from tools.squad_tools import SquadRetrieverTool, SquadQueryTool
4
  from tools.text_to_image import TextToImageTool
5
 
6
- def get_agent():
7
- # model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
8
- model_name = "http://localhost:1234/v1"
 
9
 
10
  llm_engine = HfApiEngine(model_name)
11
 
12
  TASK_SOLVING_TOOLBOX = [
13
  SquadRetrieverTool(),
14
- SquadQueryTool(),
15
- TextToImageTool(),
16
  ]
17
 
 
 
 
18
  # Initialize the agent with both tools
19
  agent = ReactCodeAgent(
20
  tools=TASK_SOLVING_TOOLBOX,
 
3
  from tools.squad_tools import SquadRetrieverTool, SquadQueryTool
4
  from tools.text_to_image import TextToImageTool
5
 
6
+ def get_agent(model_name = None, include_image_tools = False):
7
+ DEFAULT_MODEL_NAME = "http://localhost:1234/v1"
8
+ if model_name is None:
9
+ model_name = DEFAULT_MODEL_NAME
10
 
11
  llm_engine = HfApiEngine(model_name)
12
 
13
  TASK_SOLVING_TOOLBOX = [
14
  SquadRetrieverTool(),
15
+ SquadQueryTool()
 
16
  ]
17
 
18
+ if include_image_tools:
19
+ TASK_SOLVING_TOOLBOX.append(TextToImageTool())
20
+
21
  # Initialize the agent with both tools
22
  agent = ReactCodeAgent(
23
  tools=TASK_SOLVING_TOOLBOX,
app.py CHANGED
@@ -13,7 +13,7 @@ load_dotenv()
13
  sessions_path = "sessions.pkl"
14
  sessions = pickle.load(open(sessions_path, "rb")) if os.path.exists(sessions_path) else {}
15
 
16
- agent = get_agent()
17
 
18
  app = None
19
 
 
13
  sessions_path = "sessions.pkl"
14
  sessions = pickle.load(open(sessions_path, "rb")) if os.path.exists(sessions_path) else {}
15
 
16
+ agent = get_agent(model_name="meta-llama/Meta-Llama-3.1-8B-Instruct", include_image_tools=True)
17
 
18
  app = None
19
 
benchmarking.ipynb CHANGED
@@ -331,11 +331,11 @@
331
  "from agent import get_agent\n",
332
  "\n",
333
  "benchmarks = [\n",
334
- " (get_agent(), \"baseline\"),\n",
335
  "]\n",
336
  "\n",
337
  "for agent, name in tqdm(benchmarks):\n",
338
- " benchmark_agent(agent, dfSample, name)\n"
339
  ]
340
  },
341
  {
 
331
  "from agent import get_agent\n",
332
  "\n",
333
  "benchmarks = [\n",
334
+ " (get_agent(model_name=\"meta-llama/Meta-Llama-3.1-8B-Instruct\"), \"baseline\"),\n",
335
  "]\n",
336
  "\n",
337
  "for agent, name in tqdm(benchmarks):\n",
338
+ " benchmark_agent(agent, dfSample, name)"
339
  ]
340
  },
341
  {