from langchain_core.messages import BaseMessage
from langchain_core.runnables import RunnableConfig
from langgraph.func import entrypoint
from langgraph.graph.message import add_messages

from .tasks import (
    call_model,
    call_tool,
    get_recent_memories,
    get_structued_output as get_structured_output,
    manage_memories,
)
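
# Note: the .tasks helpers are presumably @task-decorated LangGraph functions;
# calling one schedules it in the background and returns a future, and
# .result() blocks until that task completes.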


@entrypoint()
def agent(messages: list[BaseMessage], config: RunnableConfig):
    """Memory-aware, tool-calling agent.

    Runs the initial model call and memory maintenance concurrently, executes
    requested tool calls in parallel, and loops until the model answers
    without requesting more tools. Returns a dict with the final ``answer``,
    any extracted ``links``, and the full ``messages`` history.
    """
    tool_calls_count = 0
    tool_names_called = set()
    answer = None
    links = []

    recent_memories = get_recent_memories().result()
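
    # Start the model call and memory maintenance concurrently; both return
    # futures that are resolved just below.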
    llm_future = call_model(messages, memories=recent_memories)
    memories_future = manage_memories(messages[-1].content)

    llm_response = llm_future.result()
    # Resolve the memory task for its side effects; the value itself is unused.
    memories_future.result()
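
    # Tool-execution loop: run requested tools, append results, and re-query
    # the model until it stops calling tools.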
    while True:
        if not llm_response.tool_calls:
            break

        # Schedule all requested tool calls before blocking so that
        # independent tools run in parallel.
        tool_result_futures = []
        for tool_call in llm_response.tool_calls:
            tool_names_called.add(tool_call["name"])
            tool_result_futures.append(call_tool(tool_call))

        tool_results = [fut.result() for fut in tool_result_futures]
        tool_calls_count += len(tool_results)

        # Record the assistant turn and tool results, then call the model again.
        messages = add_messages(messages, [llm_response, *tool_results])
        llm_response = call_model(messages, memories=recent_memories).result()
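
    # Use structured output (answer text plus links) only when a tool other
    # than search_memories actually ran; otherwise return the plain response.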
    other_tools_called = any(name != "search_memories" for name in tool_names_called)

    if tool_calls_count > 0 and other_tools_called:
        structured_output = (
            get_structured_output(llm_response.content).result().model_dump()
        )
        answer = structured_output["text"]
        links = [str(link) for link in structured_output.get("links", [])]
    else:
        answer = llm_response.content

    return {
        "answer": answer,
        "links": links,
        "messages": messages + [llm_response],
    }
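

# Example invocation (assumes the .tasks helpers are wired to a configured
# model provider; the prompt below is illustrative only):
#
#     from langchain_core.messages import HumanMessage
#
#     result = agent.invoke([HumanMessage("Summarize today's notes")])
#     print(result["answer"], result["links"])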