Upload folder using huggingface_hub
Browse files- agent.py +8 -5
- app.py +1 -1
- benchmarking.ipynb +2 -2
agent.py
CHANGED
@@ -3,18 +3,21 @@ from prompts import SQUAD_REACT_CODE_SYSTEM_PROMPT
|
|
3 |
from tools.squad_tools import SquadRetrieverTool, SquadQueryTool
|
4 |
from tools.text_to_image import TextToImageTool
|
5 |
|
6 |
-
def get_agent():
|
7 |
-
|
8 |
-
model_name
|
|
|
9 |
|
10 |
llm_engine = HfApiEngine(model_name)
|
11 |
|
12 |
TASK_SOLVING_TOOLBOX = [
|
13 |
SquadRetrieverTool(),
|
14 |
-
SquadQueryTool()
|
15 |
-
TextToImageTool(),
|
16 |
]
|
17 |
|
|
|
|
|
|
|
18 |
# Initialize the agent with both tools
|
19 |
agent = ReactCodeAgent(
|
20 |
tools=TASK_SOLVING_TOOLBOX,
|
|
|
3 |
from tools.squad_tools import SquadRetrieverTool, SquadQueryTool
|
4 |
from tools.text_to_image import TextToImageTool
|
5 |
|
6 |
+
def get_agent(model_name = None, include_image_tools = False):
|
7 |
+
DEFAULT_MODEL_NAME = "http://localhost:1234/v1"
|
8 |
+
if model_name is None:
|
9 |
+
model_name = DEFAULT_MODEL_NAME
|
10 |
|
11 |
llm_engine = HfApiEngine(model_name)
|
12 |
|
13 |
TASK_SOLVING_TOOLBOX = [
|
14 |
SquadRetrieverTool(),
|
15 |
+
SquadQueryTool()
|
|
|
16 |
]
|
17 |
|
18 |
+
if include_image_tools:
|
19 |
+
TASK_SOLVING_TOOLBOX.append(TextToImageTool())
|
20 |
+
|
21 |
# Initialize the agent with both tools
|
22 |
agent = ReactCodeAgent(
|
23 |
tools=TASK_SOLVING_TOOLBOX,
|
app.py
CHANGED
@@ -13,7 +13,7 @@ load_dotenv()
|
|
13 |
sessions_path = "sessions.pkl"
|
14 |
sessions = pickle.load(open(sessions_path, "rb")) if os.path.exists(sessions_path) else {}
|
15 |
|
16 |
-
agent = get_agent()
|
17 |
|
18 |
app = None
|
19 |
|
|
|
13 |
sessions_path = "sessions.pkl"
|
14 |
sessions = pickle.load(open(sessions_path, "rb")) if os.path.exists(sessions_path) else {}
|
15 |
|
16 |
+
agent = get_agent(model_name="meta-llama/Meta-Llama-3.1-8B-Instruct", include_image_tools=True)
|
17 |
|
18 |
app = None
|
19 |
|
benchmarking.ipynb
CHANGED
@@ -331,11 +331,11 @@
|
|
331 |
"from agent import get_agent\n",
|
332 |
"\n",
|
333 |
"benchmarks = [\n",
|
334 |
-
" (get_agent(), \"baseline\"),\n",
|
335 |
"]\n",
|
336 |
"\n",
|
337 |
"for agent, name in tqdm(benchmarks):\n",
|
338 |
-
" benchmark_agent(agent, dfSample, name)
|
339 |
]
|
340 |
},
|
341 |
{
|
|
|
331 |
"from agent import get_agent\n",
|
332 |
"\n",
|
333 |
"benchmarks = [\n",
|
334 |
+
" (get_agent(model_name=\"meta-llama/Meta-Llama-3.1-8B-Instruct\"), \"baseline\"),\n",
|
335 |
"]\n",
|
336 |
"\n",
|
337 |
"for agent, name in tqdm(benchmarks):\n",
|
338 |
+
" benchmark_agent(agent, dfSample, name)"
|
339 |
]
|
340 |
},
|
341 |
{
|