File size: 1,659 Bytes
879b924
5ea745b
879b924
e59a921
 
879b924
e59a921
 
 
 
 
 
 
800b8b7
e59a921
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ea745b
 
 
 
 
e1ed8d0
800b8b7
5ea745b
622f963
 
 
879b924
e59a921
879b924
 
 
5ea745b
879b924
5ea745b
e59a921
879b924
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os

from transformers import ReactCodeAgent, HfApiEngine
from prompts import *
from tools.squad_tools import SquadRetrieverTool, SquadQueryTool
from transformers.agents.llm_engine import MessageRole, get_clean_message_list
from openai import OpenAI

# Tools given to the agent when get_agent() is called without an explicit
# toolbox. SquadQueryTool is imported above but is not currently enabled here.
DEFAULT_TASK_SOLVING_TOOLBOX = [
    SquadRetrieverTool(),
]

# Role remapping applied before messages are sent to OpenAI:
# tool-response messages are re-labelled as user messages.
openai_role_conversions = {MessageRole.TOOL_RESPONSE: MessageRole.USER}

class OpenAIModel:
    """Callable LLM engine backed by the OpenAI chat completions API.

    An instance is passed to ``ReactCodeAgent`` as its ``llm_engine``: it is
    called with a list of agent messages (plus optional stop sequences) and
    returns the text of the model's reply.
    """

    def __init__(self, model_name="gpt-4o-mini-2024-07-18", temperature=0.5):
        """Create the OpenAI client.

        Args:
            model_name: OpenAI model identifier used for every request.
            temperature: Sampling temperature sent with every request
                (0.5 by default, matching the previous hard-coded value).
        """
        self.model_name = model_name
        self.temperature = temperature
        # API key comes from the environment; None if OPENAI_API_KEY is unset.
        self.client = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
        )

    def __call__(self, messages, stop_sequences=None):
        """Send ``messages`` to the model and return its reply text.

        Args:
            messages: Agent chat messages. They are cleaned with
                ``get_clean_message_list`` and remapped using
                ``openai_role_conversions`` before being sent.
            stop_sequences: Optional list of strings at which generation
                stops. ``None`` (the default) is treated as an empty list.

        Returns:
            The message content of the first completion choice.
        """
        # None sentinel instead of a shared mutable [] default.
        if stop_sequences is None:
            stop_sequences = []

        messages = get_clean_message_list(messages, role_conversions=openai_role_conversions)

        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            stop=stop_sequences,
            temperature=self.temperature,
        )
        return response.choices[0].message.content

def get_agent(
    model_name=None,
    system_prompt=DEFAULT_SQUAD_REACT_CODE_SYSTEM_PROMPT,
    toolbox=None,
    use_openai=True,
    openai_model_name="gpt-4o-mini-2024-07-18",
    additional_authorized_imports=None,
):
    """Build a ``ReactCodeAgent`` wired to either OpenAI or ``HfApiEngine``.

    Args:
        model_name: Model id or endpoint passed to ``HfApiEngine``; only used
            when ``use_openai`` is False. Defaults to
            ``"http://localhost:1234/v1"``.
        system_prompt: System prompt given to the agent.
        toolbox: Tools available to the agent. ``None`` (the default) uses
            ``DEFAULT_TASK_SOLVING_TOOLBOX``.
        use_openai: If True, the agent is driven by ``OpenAIModel``;
            otherwise by ``HfApiEngine``.
        openai_model_name: Model passed to ``OpenAIModel`` when
            ``use_openai`` is True.
        additional_authorized_imports: Extra modules the agent's generated
            code may import. ``None`` (the default) means ``["PIL"]``.

    Returns:
        The configured ``ReactCodeAgent``.
    """
    default_model_name = "http://localhost:1234/v1"
    if model_name is None:
        model_name = default_model_name
    # None sentinels avoid sharing mutable default objects across calls.
    if toolbox is None:
        toolbox = DEFAULT_TASK_SOLVING_TOOLBOX
    if additional_authorized_imports is None:
        additional_authorized_imports = ["PIL"]

    if use_openai:
        llm_engine = OpenAIModel(openai_model_name)
    else:
        llm_engine = HfApiEngine(model_name)

    # Initialize the agent with the selected toolbox (by default only
    # SquadRetrieverTool is enabled).
    agent = ReactCodeAgent(
        tools=toolbox,
        llm_engine=llm_engine,
        system_prompt=system_prompt,
        additional_authorized_imports=additional_authorized_imports,
    )

    return agent