Upload folder using huggingface_hub
Browse files
agent.py
CHANGED
@@ -1,25 +1,52 @@
|
|
1 |
from transformers import ReactCodeAgent, HfApiEngine
|
2 |
from prompts import *
|
3 |
from tools.squad_tools import SquadRetrieverTool, SquadQueryTool
|
|
|
|
|
4 |
|
5 |
-
DEFAULT_TASK_SOLVING_TOOLBOX = [SquadRetrieverTool(), SquadQueryTool()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
def get_agent(
|
8 |
model_name=None,
|
9 |
system_prompt=DEFAULT_SQUAD_REACT_CODE_SYSTEM_PROMPT,
|
10 |
toolbox=DEFAULT_TASK_SOLVING_TOOLBOX,
|
|
|
|
|
11 |
):
|
12 |
DEFAULT_MODEL_NAME = "http://localhost:1234/v1"
|
13 |
if model_name is None:
|
14 |
model_name = DEFAULT_MODEL_NAME
|
15 |
|
16 |
-
llm_engine = HfApiEngine(model_name)
|
17 |
|
18 |
# Initialize the agent with both tools
|
19 |
agent = ReactCodeAgent(
|
20 |
tools=toolbox,
|
21 |
llm_engine=llm_engine,
|
22 |
system_prompt=system_prompt,
|
|
|
23 |
)
|
24 |
|
25 |
return agent
|
|
|
1 |
from transformers import ReactCodeAgent, HfApiEngine
|
2 |
from prompts import *
|
3 |
from tools.squad_tools import SquadRetrieverTool, SquadQueryTool
|
4 |
+
from transformers.agents.llm_engine import MessageRole, get_clean_message_list
|
5 |
+
from openai import OpenAI
|
6 |
|
7 |
+
DEFAULT_TASK_SOLVING_TOOLBOX = [SquadRetrieverTool()] # , SquadQueryTool()
|
8 |
+
|
9 |
+
openai_role_conversions = {
|
10 |
+
MessageRole.TOOL_RESPONSE: MessageRole.USER,
|
11 |
+
}
|
12 |
+
|
13 |
+
class OpenAIModel:
|
14 |
+
def __init__(self, model_name="gpt-4o"):
|
15 |
+
self.model_name = model_name
|
16 |
+
self.client = OpenAI(
|
17 |
+
api_key=os.getenv("OPENAI_API_KEY"),
|
18 |
+
)
|
19 |
+
|
20 |
+
def __call__(self, messages, stop_sequences=[]):
|
21 |
+
messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)
|
22 |
+
|
23 |
+
response = self.client.chat.completions.create(
|
24 |
+
model=self.model_name,
|
25 |
+
messages=messages,
|
26 |
+
stop=stop_sequences,
|
27 |
+
temperature=0.5
|
28 |
+
)
|
29 |
+
return response.choices[0].message.content
|
30 |
|
31 |
def get_agent(
|
32 |
model_name=None,
|
33 |
system_prompt=DEFAULT_SQUAD_REACT_CODE_SYSTEM_PROMPT,
|
34 |
toolbox=DEFAULT_TASK_SOLVING_TOOLBOX,
|
35 |
+
use_openai=False,
|
36 |
+
openai_model_name="gpt-4o",
|
37 |
):
|
38 |
DEFAULT_MODEL_NAME = "http://localhost:1234/v1"
|
39 |
if model_name is None:
|
40 |
model_name = DEFAULT_MODEL_NAME
|
41 |
|
42 |
+
llm_engine = HfApiEngine(model_name) if not use_openai else OpenAIModel(openai_model_name)
|
43 |
|
44 |
# Initialize the agent with both tools
|
45 |
agent = ReactCodeAgent(
|
46 |
tools=toolbox,
|
47 |
llm_engine=llm_engine,
|
48 |
system_prompt=system_prompt,
|
49 |
+
additional_authorized_imports=["PIL"],
|
50 |
)
|
51 |
|
52 |
return agent
|