vonliechti commited on
Commit
e59a921
·
verified ·
1 Parent(s): 5ea745b

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. agent.py +29 -2
agent.py CHANGED
@@ -1,25 +1,52 @@
1
  from transformers import ReactCodeAgent, HfApiEngine
2
  from prompts import *
3
  from tools.squad_tools import SquadRetrieverTool, SquadQueryTool
 
 
4
 
5
- DEFAULT_TASK_SOLVING_TOOLBOX = [SquadRetrieverTool(), SquadQueryTool()]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def get_agent(
8
  model_name=None,
9
  system_prompt=DEFAULT_SQUAD_REACT_CODE_SYSTEM_PROMPT,
10
  toolbox=DEFAULT_TASK_SOLVING_TOOLBOX,
 
 
11
  ):
12
  DEFAULT_MODEL_NAME = "http://localhost:1234/v1"
13
  if model_name is None:
14
  model_name = DEFAULT_MODEL_NAME
15
 
16
- llm_engine = HfApiEngine(model_name)
17
 
18
  # Initialize the agent with both tools
19
  agent = ReactCodeAgent(
20
  tools=toolbox,
21
  llm_engine=llm_engine,
22
  system_prompt=system_prompt,
 
23
  )
24
 
25
  return agent
 
1
  from transformers import ReactCodeAgent, HfApiEngine
2
  from prompts import *
3
  from tools.squad_tools import SquadRetrieverTool, SquadQueryTool
4
+ from transformers.agents.llm_engine import MessageRole, get_clean_message_list
5
+ from openai import OpenAI
6
 
7
+ DEFAULT_TASK_SOLVING_TOOLBOX = [SquadRetrieverTool()] # , SquadQueryTool()
8
+
9
+ openai_role_conversions = {
10
+ MessageRole.TOOL_RESPONSE: MessageRole.USER,
11
+ }
12
+
13
+ class OpenAIModel:
14
+ def __init__(self, model_name="gpt-4o"):
15
+ self.model_name = model_name
16
+ self.client = OpenAI(
17
+ api_key=os.getenv("OPENAI_API_KEY"),
18
+ )
19
+
20
+ def __call__(self, messages, stop_sequences=[]):
21
+ messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)
22
+
23
+ response = self.client.chat.completions.create(
24
+ model=self.model_name,
25
+ messages=messages,
26
+ stop=stop_sequences,
27
+ temperature=0.5
28
+ )
29
+ return response.choices[0].message.content
30
 
31
  def get_agent(
32
  model_name=None,
33
  system_prompt=DEFAULT_SQUAD_REACT_CODE_SYSTEM_PROMPT,
34
  toolbox=DEFAULT_TASK_SOLVING_TOOLBOX,
35
+ use_openai=False,
36
+ openai_model_name="gpt-4o",
37
  ):
38
  DEFAULT_MODEL_NAME = "http://localhost:1234/v1"
39
  if model_name is None:
40
  model_name = DEFAULT_MODEL_NAME
41
 
42
+ llm_engine = HfApiEngine(model_name) if not use_openai else OpenAIModel(openai_model_name)
43
 
44
  # Initialize the agent with both tools
45
  agent = ReactCodeAgent(
46
  tools=toolbox,
47
  llm_engine=llm_engine,
48
  system_prompt=system_prompt,
49
+ additional_authorized_imports=["PIL"],
50
  )
51
 
52
  return agent