|
import os

# NOTE: the relative order of the imports below is kept as-is on purpose:
# ``from prompts import *`` can shadow any name imported before it.
from transformers import ReactCodeAgent, HfApiEngine
from prompts import *
from tools.squad_tools import SquadRetrieverTool, SquadQueryTool
from transformers.agents.llm_engine import MessageRole, get_clean_message_list
from openai import OpenAI
from prompts import FOCUSED_SQUAD_REACT_CODE_SYSTEM_PROMPT
|
|
|
# Tools given to agents built by get_agent() when no toolbox is passed.
# The tool is instantiated once at import time, so every agent built with
# this default shares the same instance.
DEFAULT_TASK_SOLVING_TOOLBOX = [
    SquadRetrieverTool(),
]
|
|
|
# Role remapping applied by OpenAIModel.__call__ (via get_clean_message_list)
# before messages are sent: tool-response messages are re-labelled as user
# messages, presumably because the OpenAI chat API has no equivalent role.
openai_role_conversions = {MessageRole.TOOL_RESPONSE: MessageRole.USER}
|
|
|
class OpenAIModel:
    """Callable LLM engine that sends an agent conversation to the OpenAI chat API.

    Used as the ``llm_engine`` of the agent built by ``get_agent`` when
    ``use_openai`` is True.
    """

    def __init__(self, model_name="gpt-4o-mini-2024-07-18", temperature=0.5):
        """Create the OpenAI client.

        Args:
            model_name: OpenAI chat model identifier used for every request.
            temperature: Sampling temperature sent with every request. The
                default (0.5) matches the value previously hard-coded.
        """
        self.model_name = model_name
        self.temperature = temperature
        # If OPENAI_API_KEY is unset, getenv returns None and the OpenAI client
        # applies its own environment lookup / error handling.
        self.client = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
        )

    def __call__(self, messages, stop_sequences=None):
        """Return the model's completion text for ``messages``.

        Args:
            messages: Agent message list. It is normalised with
                ``get_clean_message_list`` using ``openai_role_conversions``
                before being sent.
            stop_sequences: Strings at which generation should stop. ``None``
                or an empty sequence means no stop sequences are sent.

        Returns:
            The text of the first choice, or ``""`` if the API returned no
            text content.
        """
        messages = get_clean_message_list(
            messages, role_conversions=openai_role_conversions
        )

        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            # Send null (the API default) instead of an empty list when there
            # is nothing to stop on; also avoids a shared mutable default.
            stop=list(stop_sequences) if stop_sequences else None,
            temperature=self.temperature,
        )

        # ``message.content`` is Optional in the OpenAI SDK (e.g. refusals);
        # callers treat the result as text, so never hand back None.
        content = response.choices[0].message.content
        return content if content is not None else ""
|
|
|
def get_agent(
    model_name=None,
    system_prompt=FOCUSED_SQUAD_REACT_CODE_SYSTEM_PROMPT,
    toolbox=DEFAULT_TASK_SOLVING_TOOLBOX,
    use_openai=True,
    openai_model_name="gpt-4o-mini-2024-07-18",
    additional_authorized_imports=None,
):
    """Build a ``ReactCodeAgent`` backed by either OpenAI or ``HfApiEngine``.

    Args:
        model_name: Model id or endpoint URL for ``HfApiEngine``. It defaults
            to ``"http://localhost:1234/v1"`` and is only used when
            ``use_openai`` is False.
        system_prompt: System prompt handed to the agent.
        toolbox: Tools exposed to the agent. The default is the module-level
            ``DEFAULT_TASK_SOLVING_TOOLBOX``, whose tool instances are shared
            by every agent built with it.
        use_openai: If True, use ``OpenAIModel``; otherwise use ``HfApiEngine``.
        openai_model_name: Model name for ``OpenAIModel`` when ``use_openai``
            is True.
        additional_authorized_imports: Extra modules forwarded to
            ``ReactCodeAgent`` as authorized imports. Defaults to ``["PIL"]``.

    Returns:
        The configured ``ReactCodeAgent``.
    """
    if use_openai:
        llm_engine = OpenAIModel(openai_model_name)
    else:
        # Fall back to a local OpenAI-compatible endpoint when no model is given.
        default_model_name = "http://localhost:1234/v1"
        llm_engine = HfApiEngine(
            model_name if model_name is not None else default_model_name
        )

    if additional_authorized_imports is None:
        additional_authorized_imports = ["PIL"]

    agent = ReactCodeAgent(
        tools=toolbox,
        llm_engine=llm_engine,
        system_prompt=system_prompt,
        # Copy so the agent never holds a reference to the caller's list.
        additional_authorized_imports=list(additional_authorized_imports),
    )

    return agent
|
|