from langgraph.func import task
from langchain_core.messages import (
    SystemMessage,
    ToolMessage,
    HumanMessage,
    BaseMessage,
)
from langchain_core.runnables import RunnableConfig

from .memory_client import memory_client
from .tools import search_memories, search_vectorstore
from .llms import main_model, output_formatter_model
from .prompts import MAKHFI_AI_PROMPT, OUTPUT_FORMATTER_PROMPT
from .schemas import OutputFormat

# Tools the main agent may call, indexed by name so that tool calls emitted by
# the model can be dispatched in `call_tool`.
tools = [search_memories, search_vectorstore]
tools_by_name = {tool.name: tool for tool in tools}

# Main chat model with the tools bound, and a second model constrained to
# produce `OutputFormat` instances.
agent_with_tools = main_model.bind_tools(tools)
output_formatter = output_formatter_model.with_structured_output(OutputFormat)


def _get_user_id(config: RunnableConfig) -> str:
    """Return ``config["configurable"]["user_id"]``.

    Raises:
        ValueError: if the user id is missing or empty.
    """
    user_id = config.get("configurable", {}).get("user_id")
    if not user_id:
        raise ValueError("User Id not found in config")
    return user_id


@task
def call_model(messages: list[BaseMessage], memories: str) -> BaseMessage:
    """Call the tool-enabled main model on a conversation.

    A system message built from ``MAKHFI_AI_PROMPT`` (with ``memories``
    substituted into its ``{memories}`` placeholder) is prepended to
    ``messages`` before invoking the model.

    Args:
        messages: Conversation history to send after the system prompt.
        memories: Pre-formatted memory text to inject into the system prompt.

    Returns:
        The model's response message (which may contain tool calls).
    """
    system_prompt = SystemMessage(content=MAKHFI_AI_PROMPT.format(memories=memories))
    return agent_with_tools.invoke([system_prompt] + messages)


@task
def get_structued_output(agent_response: str) -> OutputFormat:
    """Convert the agent's free-text answer into an ``OutputFormat`` object.

    Args:
        agent_response: Raw text produced by the main agent.

    Returns:
        The structured result produced by the output-formatter model.
    """
    response: OutputFormat = output_formatter.invoke(
        [
            SystemMessage(content=OUTPUT_FORMATTER_PROMPT),
            HumanMessage(content=agent_response),
        ]
    )
    return response


# Correctly spelled alias; the misspelled name is kept for existing callers.
get_structured_output = get_structued_output


@task
def manage_memories(user_message: str, config: RunnableConfig):
    """Send a user message to the memory service so it can add/update memories.

    Args:
        user_message: The user's message text.
        config: Runnable config; must contain ``configurable.user_id``.

    Returns:
        Whatever ``memory_client.add`` returns.

    Raises:
        ValueError: if ``user_id`` is missing from ``config`` (previously the
            call went through with ``user_id=None``).
    """
    user_id = _get_user_id(config)
    messages = [{"role": "user", "content": user_message}]
    return memory_client.add(
        messages, user_id=user_id, version="v2", output_format="v1.1"
    )


@task
def get_recent_memories(config: RunnableConfig, limit: int = 10) -> str:
    """Return the user's most recently updated memories as a bullet list.

    Memories are ordered by ``updated_at`` (newest first) and truncated to
    ``limit`` entries. Each is rendered as ``"- <memory>"`` on its own line.

    Args:
        config: Runnable config; must contain ``configurable.user_id``.
        limit: Maximum number of memories to include (default 10).

    Returns:
        The formatted list, or a placeholder string when no memories exist.

    Raises:
        ValueError: if ``user_id`` is missing from ``config``.
    """
    user_id = _get_user_id(config)
    memories = memory_client.get_all(version="v2", filters={"user_id": user_id})
    # NOTE(review): depending on client version / output format, get_all may
    # wrap results as {"results": [...]}; accept both shapes.
    if isinstance(memories, dict):
        memories = memories.get("results") or []
    if not memories:
        return "(No information is available about the user yet)"

    # Tolerate a missing/None `updated_at` so one bad record cannot break the
    # sort; such records sort last. Assumes timestamps share one ISO format.
    recent_memories = sorted(
        memories, key=lambda m: m.get("updated_at") or "", reverse=True
    )[:limit]
    return "\n".join(f"- {mem['memory']}" for mem in recent_memories)


@task
def call_tool(tool_call, config: RunnableConfig) -> ToolMessage:
    """Execute one tool call emitted by the model.

    Args:
        tool_call: A tool-call mapping with ``name``, ``args`` and ``id`` keys.
        config: Runnable config, forwarded to the tool so that tools needing
            per-run settings (e.g. ``configurable.user_id``) can read it.

    Returns:
        A ``ToolMessage`` holding the tool's output, linked to the call's id.
        If the model requested an unknown tool, an error ``ToolMessage`` is
        returned instead of raising, so the model can correct itself.
    """
    tool = tools_by_name.get(tool_call["name"])
    if tool is None:
        available = ", ".join(tools_by_name)
        return ToolMessage(
            content=(
                f"Error: unknown tool {tool_call['name']!r}. "
                f"Available tools: {available}."
            ),
            tool_call_id=tool_call["id"],
            status="error",
        )
    observation = tool.invoke(tool_call["args"], config=config)
    return ToolMessage(content=observation, tool_call_id=tool_call["id"])