# unit-four-final-project / functions / agent_helper_functions.py
# Commit b4e2809 (gperdrizet, verified): switched to a single agent powered by
# GPT-4.1; added a step wait function to avoid hitting the OpenAI API rate limit.
'''Helper functions for the agent(s) in the GAIA question answering system.'''
import os
import time
import json
import logging
from openai import OpenAI
from smolagents import CodeAgent, ActionStep, MessageRole
from configuration import CHECK_MODEL, TOKEN_LIMITER, STEP_WAIT
# Get logger for this module
logger = logging.getLogger(__name__)
def check_reasoning(final_answer: str, agent_memory) -> bool:
    """Asks CHECK_MODEL to verify the agent's reasoning and final answer.

    Sends the agent's succinct step history plus the proposed final answer to
    the checker model and parses its PASS/FAIL verdict.

    Args:
        final_answer: The agent's proposed final answer.
        agent_memory: Agent memory object exposing get_succinct_steps().

    Returns:
        True when the checker's feedback does not contain 'FAIL'.

    Raises:
        Exception: With the checker's full feedback text when the verdict
            contains 'FAIL'.
    """

    prompt = (
        f"Here is a user-given task and the agent steps: {agent_memory.get_succinct_steps()}. " +
        "Please check that the reasoning process and answer are correct. " +
        "Do they correctly answer the given task? " +
        "First list reasons why yes/no, then write your final decision: " +
        "PASS in caps lock if it is satisfactory, FAIL if it is not. " +
        f"Final answer: {str(final_answer)}"
    )

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt,
                }
            ],
        }
    ]

    output = CHECK_MODEL(messages).content

    # Route checker feedback through the module logger rather than stdout,
    # consistent with the rest of this module's logging.
    logger.info("Feedback: %s", output)

    if "FAIL" in output:
        raise Exception(output) # pylint:disable=broad-exception-raised

    return True
def step_memory_cap(memory_step: ActionStep, agent: CodeAgent) -> None:
    '''Step callback: keeps agent context length under control.

    Trims agent memory down to the task step, the planning step, and the
    latest step; when token usage still exceeds TOKEN_LIMITER, replaces the
    conversation with a model-generated summary.

    Args:
        memory_step: The ActionStep just completed by the agent.
        agent: The running CodeAgent whose memory is being managed.
    '''

    # Only trim once the task step, planning step, and at least one action
    # step all exist. Indexing steps[1] before this check would raise
    # IndexError on the very first callback invocation.
    if len(agent.memory.steps) > 2:
        task_step = agent.memory.steps[0]
        planning_step = agent.memory.steps[1]
        latest_step = agent.memory.steps[-1]
        agent.memory.steps = [task_step, planning_step, latest_step]

    logger.info('Agent memory has %d steps', len(agent.memory.steps))
    logger.info('Latest step is step %d', memory_step.step_number)
    logger.info('Contains: %s messages', len(agent.memory.steps[-1].model_input_messages))
    logger.info('Token usage: %s', agent.memory.steps[-1].token_usage.total_tokens)

    for message in agent.memory.steps[-1].model_input_messages:
        logger.debug(' Role: %s: %s', message['role'], message['content'][:100])

    token_usage = agent.memory.steps[-1].token_usage.total_tokens

    if token_usage > TOKEN_LIMITER:
        logger.info('Token usage is %d, summarizing old messages', token_usage)

        # Summarize everything except the first (system) message.
        summary = summarize_old_messages(
            agent.memory.steps[-1].model_input_messages[1:]
        )

        # On summarization failure, leave memory untouched (best-effort).
        if summary is not None:
            # Keep the system message, then replace the rest of the
            # conversation with a single user message carrying the summary.
            new_messages = [agent.memory.steps[-1].model_input_messages[0]]
            new_messages.append({
                'role': MessageRole.USER,
                'content': [{
                    'type': 'text',
                    'text': f'Here is a summary of your investigation so far: {summary}'
                }]
            })

            agent.memory.steps = [agent.memory.steps[0]]
            agent.memory.steps[0].model_input_messages = new_messages

            for message in agent.memory.steps[0].model_input_messages:
                logger.debug(' Role: %s: %s', message['role'], message['content'][:100])
def summarize_old_messages(messages: list) -> str | None:
    '''Summarizes old agent messages to keep context length under control.

    Calls a Modal-hosted, OpenAI-compatible summarization endpoint with the
    serialized message history.

    Args:
        messages: List of message dictionaries to summarize.

    Returns:
        The summary text, or None if the summarization API call failed.
    '''

    client = OpenAI(api_key=os.environ['MODAL_API_KEY'])
    client.base_url = (
        'https://gperdrizet--vllm-openai-compatible-summarization-serve.modal.run/v1'
    )

    # Default to the first available model served by the endpoint
    model_id = client.models.list().data[0].id

    # Build the request payload in a new name so the 'messages' parameter is
    # not shadowed/rebound mid-function.
    request_messages = [
        {
            'role': 'system',
            'content': ('Summarize the following interaction between an AI agent and a user.' +
                f'Return the summary formatted as text, not as JSON: {json.dumps(messages)}')
        }
    ]

    try:
        response = client.chat.completions.create(
            model=model_id,
            messages=request_messages,
        )

    except Exception as e: # pylint: disable=broad-exception-caught
        # Best-effort: log and signal failure to the caller rather than crash.
        logger.error('Error during Modal API call: %s', e)
        return None

    return response.choices[0].message.content
def step_wait(memory_step: ActionStep, agent: CodeAgent) -> None:
    '''Step callback: sleeps for STEP_WAIT seconds after each agent step to
    avoid hitting API rate limits.

    Args:
        memory_step: The ActionStep just completed by the agent.
        agent: The running CodeAgent (used only for logging its step count).
    '''

    logger.info('Waiting for %d seconds to prevent hitting API rate limits', STEP_WAIT)
    logger.info('Current step is %d', memory_step.step_number)
    logger.info('Current agent has %d steps', len(agent.memory.steps))

    time.sleep(STEP_WAIT)

    # Implicitly returns None, matching the declared signature; the previous
    # 'return True' contradicted the '-> None' annotation and its value was
    # never used.