'''Helper functions for the agent(s) in the GAIA question answering system.'''

import os
import time
import json
import logging
from openai import OpenAI
from smolagents import CodeAgent, ActionStep, MessageRole
from configuration import CHECK_MODEL, TOKEN_LIMITER, STEP_WAIT

# Get logger for this module
logger = logging.getLogger(__name__)


def check_reasoning(final_answer: str, agent_memory) -> bool:
    '''Checks the reasoning behind the agent's final answer; raises an exception if the check fails.'''

    prompt = (
        f"Here is a user-given task and the agent steps: {agent_memory.get_succinct_steps()}. " +
        "Please check that the reasoning process and answer are correct. " +
        "Do they correctly answer the given task? " +
        "First list reasons why yes/no, then write your final decision: " +
        "PASS in caps lock if it is satisfactory, FAIL if it is not. " +
        f"Final answer: {str(final_answer)}"
    )

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt,
                }
            ],
        }
    ]

    output = CHECK_MODEL(messages).content
    print("Feedback: ", output)

    if "FAIL" in output:
        raise Exception(output) # pylint:disable=broad-exception-raised

    return True
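
# Hedged usage sketch: this checker matches the (final_answer, agent_memory)
# signature that smolagents' final_answer_checks hook calls with, so it can be
# attached at agent construction. The empty tool list and model choice are
# illustrative assumptions, not this project's real setup:
#
# agent = CodeAgent(
#     tools=[],
#     model=CHECK_MODEL,
#     final_answer_checks=[check_reasoning],
# )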


def step_memory_cap(memory_step: ActionStep, agent: CodeAgent) -> None:
    '''Removes old steps from agent memory to keep context length under control.'''

    # Keep only the task step, the planning step and the most recent step;
    # guard the indexing so early callbacks with fewer steps don't raise IndexError
    if len(agent.memory.steps) > 2:
        task_step = agent.memory.steps[0]
        planning_step = agent.memory.steps[1]
        latest_step = agent.memory.steps[-1]
        agent.memory.steps = [task_step, planning_step, latest_step]

    latest = agent.memory.steps[-1]

    logger.info('Agent memory has %d steps', len(agent.memory.steps))
    logger.info('Latest step is step %d', memory_step.step_number)
    logger.info('Latest step contains %d messages', len(latest.model_input_messages))
    logger.info('Token usage: %d', latest.token_usage.total_tokens)

    for message in latest.model_input_messages:
        logger.debug(' Role: %s: %s', message['role'], str(message['content'])[:100])

    token_usage = latest.token_usage.total_tokens

    if token_usage > TOKEN_LIMITER:
        logger.info('Token usage is %d, summarizing old messages', token_usage)

        summary = summarize_old_messages(
            agent.memory.steps[-1].model_input_messages[1:]
        )

        if summary is not None:

            # Rebuild the message list as the original system message plus a
            # single user message carrying the summary of everything since
            new_messages = [agent.memory.steps[-1].model_input_messages[0]]
            new_messages.append({
                'role': MessageRole.USER,
                'content': [{
                    'type': 'text',
                    'text': f'Here is a summary of your investigation so far: {summary}'
                }]
            })

            # Collapse memory to just the task step and hand it the new messages
            agent.memory.steps = [agent.memory.steps[0]]
            agent.memory.steps[0].model_input_messages = new_messages

        for message in agent.memory.steps[0].model_input_messages:
            logger.debug(' Role: %s: %s', message['role'], str(message['content'])[:100])
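
# Hedged usage sketch: smolagents invokes step callbacks with the finished
# memory step and the agent, which is the signature step_memory_cap uses, so
# it can be registered like this (constructor arguments are assumptions):
#
# agent = CodeAgent(
#     tools=[],
#     model=CHECK_MODEL,
#     step_callbacks=[step_memory_cap],
# )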


def summarize_old_messages(messages: list) -> str | None:
    '''Summarizes old messages to keep context length under control.'''

    client = OpenAI(
        api_key=os.environ['MODAL_API_KEY'],
        base_url='https://gperdrizet--vllm-openai-compatible-summarization-serve.modal.run/v1'
    )

    # Default to the first available model
    model = client.models.list().data[0]
    model_id = model.id

    summarization_messages = [
        {
            'role': 'system',
            'content': ('Summarize the following interaction between an AI agent and a user. ' +
                f'Return the summary formatted as text, not as JSON: {json.dumps(messages)}')
        }
    ]

    completion_args = {
        'model': model_id,
        'messages': summarization_messages,
    }

    try:
        response = client.chat.completions.create(**completion_args)

    except Exception as e: # pylint: disable=broad-exception-caught
        logger.error('Error during Modal API call: %s', e)
        return None

    return response.choices[0].message.content
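
# Input shape sketch (hypothetical contents): the function expects the
# JSON-serializable message dicts that smolagents stores in
# model_input_messages, e.g.:
#
# summary = summarize_old_messages([
#     {'role': 'user', 'content': [{'type': 'text', 'text': 'Find the oldest...'}]},
#     {'role': 'assistant', 'content': [{'type': 'text', 'text': 'I will search...'}]},
# ])
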

def step_wait(memory_step: ActionStep, agent: CodeAgent) -> None:
    '''Waits for a while to prevent hitting API rate limits.'''

    logger.info('Waiting for %d seconds to prevent hitting API rate limits', STEP_WAIT)
    logger.info('Current step is %d', memory_step.step_number)
    logger.info('Current agent has %d steps', len(agent.memory.steps))

    time.sleep(STEP_WAIT)
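

# Minimal end-to-end sketch, assuming CHECK_MODEL also works as the agent's
# model and that no tools are needed; the real application wires its agent up
# elsewhere, so this only runs when the module is executed directly.
if __name__ == '__main__':
    demo_agent = CodeAgent(
        tools=[],
        model=CHECK_MODEL,  # assumption: reusing the checker model for the demo
        step_callbacks=[step_memory_cap, step_wait],
        final_answer_checks=[check_reasoning],
    )
    print(demo_agent.run('What is the capital of France?'))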