|
|
from langgraph.func import task |
|
|
from langchain_core.messages import ( |
|
|
SystemMessage, |
|
|
ToolMessage, |
|
|
HumanMessage, |
|
|
BaseMessage, |
|
|
) |
|
|
from langchain_core.runnables import RunnableConfig |
|
|
|
|
|
from .memory_client import memory_client |
|
|
from .tools import search_memories, search_vectorstore |
|
|
from .llms import main_model, output_formatter_model |
|
|
from .prompts import MAKHFI_AI_PROMPT, OUTPUT_FORMATTER_PROMPT |
|
|
from .schemas import OutputFormat |
|
|
|
|
|
# Tools exposed to the main model.
tools = [search_memories, search_vectorstore]

# Lookup table from each tool's name to the tool object; used to dispatch
# the tool calls emitted by the model.
tools_by_name = {registered.name: registered for registered in tools}

# Main chat model with the tools above bound to it.
agent_with_tools = main_model.bind_tools(tools)

# Secondary model configured to produce OutputFormat instances.
output_formatter = output_formatter_model.with_structured_output(OutputFormat)
|
|
|
|
|
@task
def call_model(messages: list[BaseMessage], memories: str):
    """Invoke the tool-enabled model on the conversation.

    A system message built from ``MAKHFI_AI_PROMPT`` (with ``memories``
    substituted for its ``{memories}`` placeholder) is placed ahead of the
    conversation history.

    Args:
        messages: Conversation history sent after the system message.
        memories: Text inserted into the prompt's ``{memories}`` slot.

    Returns:
        The result of ``agent_with_tools.invoke`` for the assembled prompt.
    """
    system_prompt = MAKHFI_AI_PROMPT.format(memories=memories)
    prompt = [SystemMessage(content=system_prompt)] + messages
    return agent_with_tools.invoke(prompt)
|
|
|
|
|
@task
def get_structued_output(agent_response: str) -> OutputFormat:
    """Reformat a free-form agent reply into an ``OutputFormat`` object.

    The formatter model receives ``OUTPUT_FORMATTER_PROMPT`` as the system
    message and the agent's reply as the human message.

    NOTE: the function name is misspelled ("structued"); it is kept unchanged
    so existing references to it keep working.

    Args:
        agent_response: Raw text produced by the main agent.

    Returns:
        The structured result produced by ``output_formatter.invoke``.
    """
    formatter_input = [
        SystemMessage(content=OUTPUT_FORMATTER_PROMPT),
        HumanMessage(content=agent_response),
    ]
    return output_formatter.invoke(formatter_input)
|
|
|
|
|
@task
def manage_memories(user_message: str, config: RunnableConfig):
    """Send the user's message to the memory service for the current user.

    Args:
        user_message: Text of the user's latest message.
        config: Run config; ``config["configurable"]["user_id"]`` selects the
            user the memory is stored for.

    Returns:
        Whatever ``memory_client.add`` returns.

    Raises:
        ValueError: If no ``user_id`` is present in ``config``.
    """
    # `or {}` also covers an explicit `configurable=None`.
    user_id = (config.get("configurable") or {}).get("user_id")

    # Fail fast (same check and message as the memory-retrieval task) instead
    # of calling the memory service with user_id=None.
    if not user_id:
        raise ValueError("User Id not found in config")

    message = [{"role": "user", "content": user_message}]
    return memory_client.add(
        message, user_id=user_id, version="v2", output_format="v1.1"
    )
|
|
|
|
|
@task
def get_recent_memories(config: RunnableConfig, limit: int = 10):
    """Return the user's most recently updated memories as a bullet list.

    Memories are fetched with ``memory_client.get_all`` and ordered by their
    ``updated_at`` value, newest first; at most ``limit`` entries are kept.

    Args:
        config: Run config; ``config["configurable"]["user_id"]`` selects
            whose memories are fetched.
        limit: Maximum number of memories to include (default 10).

    Returns:
        One ``"- <memory>"`` line per memory joined by newlines, or a
        placeholder sentence when no memories are available.

    Raises:
        ValueError: If no ``user_id`` is present in ``config``.
    """
    # `or {}` also covers an explicit `configurable=None`.
    user_id = (config.get("configurable") or {}).get("user_id")
    if not user_id:
        raise ValueError("User Id not found in config")

    memories = memory_client.get_all(version="v2", filters={"user_id": user_id})

    # NOTE(review): some client output formats wrap the list as
    # {"results": [...]}; unwrap defensively so the sort below does not iterate
    # dict keys. Confirm against the mem0 client version in use.
    if isinstance(memories, dict):
        memories = memories.get("results") or []

    if not memories:
        return "(No information is available about the user yet)"

    # Entries with a missing/None timestamp map to "" and therefore sort last
    # (reverse order) instead of raising KeyError/TypeError. Assumes timestamps
    # are comparable strings (e.g. ISO-8601 from JSON) — TODO confirm.
    recent_memories = sorted(
        memories, key=lambda m: m.get("updated_at") or "", reverse=True
    )[:limit]

    return "\n".join(f"- {mem['memory']}" for mem in recent_memories)
|
|
|
|
|
@task
def call_tool(tool_call, config: RunnableConfig):
    """Execute one tool call requested by the model and wrap its result.

    Args:
        tool_call: Mapping with ``"name"``, ``"args"`` and ``"id"`` keys
            describing the requested call.
        config: Run config. NOTE(review): it is not passed to ``tool.invoke``
            here; confirm the tools obtain what they need (e.g. ``user_id``)
            through another route.

    Returns:
        A ``ToolMessage`` linked to ``tool_call["id"]`` holding the tool's
        output, or an error description (``status="error"``) when the named
        tool is not registered.
    """
    tool_name = tool_call["name"]
    tool = tools_by_name.get(tool_name)
    if tool is None:
        # Report an unregistered tool name back to the model as an error
        # message rather than aborting the whole run with a KeyError.
        available = ", ".join(tools_by_name)
        return ToolMessage(
            content=(
                f"Error: unknown tool '{tool_name}'. "
                f"Available tools: {available}."
            ),
            tool_call_id=tool_call["id"],
            status="error",
        )

    observation = tool.invoke(tool_call["args"])
    return ToolMessage(content=observation, tool_call_id=tool_call["id"])
|
|
|