# Makhfi_AI / workflow / tasks.py
# Last commit 8477204 (author: Aasher):
#   fix(call_model): add memories parameter to ensure memories are passed to agent
import heapq

from langchain_core.messages import (
    SystemMessage,
    ToolMessage,
    HumanMessage,
    BaseMessage,
)
from langchain_core.runnables import RunnableConfig
from langgraph.func import task

from .memory_client import memory_client
from .tools import search_memories, search_vectorstore
from .llms import main_model, output_formatter_model
from .prompts import MAKHFI_AI_PROMPT, OUTPUT_FORMATTER_PROMPT
from .schemas import OutputFormat
# Tools the agent is allowed to call; `call_tool` resolves requests by name
# through `tools_by_name`.
tools = [search_memories, search_vectorstore]
tools_by_name = dict((registered.name, registered) for registered in tools)

# Main chat model with the tool schemas bound to it.
agent_with_tools = main_model.bind_tools(tools)

# Formatter model configured to return structured `OutputFormat` objects.
output_formatter = output_formatter_model.with_structured_output(OutputFormat)
@task
def call_model(messages: list[BaseMessage], memories: str):
    """Invoke the tool-enabled agent model on a conversation.

    A system message is rendered from ``MAKHFI_AI_PROMPT`` by substituting
    ``memories`` into its ``{memories}`` placeholder. That message is placed
    in front of ``messages``, and the combined list is sent to
    ``agent_with_tools``.

    Args:
        messages: Conversation history to follow the system prompt.
        memories: Text about the user to embed in the system prompt
            (presumably the string built by ``get_recent_memories`` — confirm
            at the call site).

    Returns:
        The result of ``agent_with_tools.invoke`` (presumably an AI message
        that may carry tool calls — confirm).
    """
    system_prompt = MAKHFI_AI_PROMPT.format(memories=memories)
    conversation = [SystemMessage(content=system_prompt)] + messages
    return agent_with_tools.invoke(conversation)
@task
def get_structued_output(agent_response: str) -> OutputFormat:
    """Reformat the agent's free-text reply into an ``OutputFormat`` object.

    The formatter model receives ``OUTPUT_FORMATTER_PROMPT`` as the system
    message and the agent's reply as the human message.

    NOTE: the public name is misspelled ("structued"). It is kept unchanged
    because renaming it would break existing callers.

    Args:
        agent_response: Text produced by the agent.

    Returns:
        The structured result produced by ``output_formatter``.
    """
    formatter_input = [
        SystemMessage(content=OUTPUT_FORMATTER_PROMPT),
        HumanMessage(content=agent_response),
    ]
    structured: OutputFormat = output_formatter.invoke(formatter_input)
    return structured
@task
def manage_memories(user_message: str, config: RunnableConfig):
    """Send a user message to the memory service for memory extraction.

    The message is passed to ``memory_client.add`` as a single ``"user"``
    role entry. It is scoped to the ``user_id`` taken from
    ``config["configurable"]`` and uses ``version="v2"`` and
    ``output_format="v1.1"``.

    Args:
        user_message: Raw text of the user's message.
        config: Runnable config that must carry ``configurable.user_id``.

    Returns:
        Whatever ``memory_client.add`` returns. This is presumably a report of
        the memory operations performed rather than the memories themselves;
        confirm against the memory client.

    Raises:
        ValueError: If no ``user_id`` is present in ``config``.
    """
    # `or {}` also covers an explicit `configurable=None`.
    user_id = (config.get("configurable") or {}).get("user_id")
    # Never write memories without an owner. Forwarding user_id=None would
    # risk unscoped or cross-user memories. This matches the user_id check
    # performed before reading memories in get_recent_memories.
    if not user_id:
        raise ValueError("User Id not found in config")
    messages = [{"role": "user", "content": user_message}]
    return memory_client.add(
        messages, user_id=user_id, version="v2", output_format="v1.1"
    )
def _memory_recency_key(memory: dict) -> str:
    """Return the recency sort key for one memory record."""
    # The original `m["updated_at"]` raised KeyError, or TypeError when a
    # record's value was None. Fall back to created_at, then to "", so such
    # records rank last instead of crashing.
    # Assumes timestamps are ISO-8601 strings in a single timezone, so string
    # order equals time order — TODO confirm against the memory API.
    return memory.get("updated_at") or memory.get("created_at") or ""


@task
def get_recent_memories(config: RunnableConfig, limit: int = 10) -> str:
    """Return the user's most recently updated memories as a bullet list.

    Memories come from ``memory_client.get_all`` (``version="v2"``, filtered
    by ``user_id``). The ``limit`` newest by ``updated_at`` are rendered one
    per line as ``"- <memory>"``, newest first.

    Args:
        config: Runnable config that must carry ``configurable.user_id``.
        limit: Maximum number of memories to include (default 10, the
            previous hard-coded value).

    Returns:
        A newline-separated bullet list, or a placeholder sentence when the
        user has no memories yet.

    Raises:
        ValueError: If no ``user_id`` is present in ``config``.
    """
    # `or {}` also covers an explicit `configurable=None`.
    user_id = (config.get("configurable") or {}).get("user_id")
    if not user_id:
        raise ValueError("User Id not found in config")
    memories = memory_client.get_all(version="v2", filters={"user_id": user_id})
    if not memories:
        return "(No information is available about the user yet)"
    # heapq.nlargest is documented as equivalent to
    # sorted(..., reverse=True)[:limit]. It runs in O(n log limit) instead of
    # fully sorting every memory.
    recent_memories = heapq.nlargest(limit, memories, key=_memory_recency_key)
    return "\n".join(f"- {mem['memory']}" for mem in recent_memories)
@task
def call_tool(tool_call, config: RunnableConfig):
    """Run one tool call requested by the model and wrap its output.

    Args:
        tool_call: Mapping with ``"name"``, ``"args"`` and ``"id"`` entries
            describing the requested call.
        config: Runnable config. NOTE(review): it is not forwarded to
            ``tool.invoke``. Confirm whether the tools rely on LangGraph's
            implicit config propagation instead.

    Returns:
        A ``ToolMessage`` holding the tool's output. The message is linked to
        the originating call through ``tool_call_id``.

    Raises:
        KeyError: If the requested tool name is not in ``tools_by_name``.
    """
    selected_tool = tools_by_name[tool_call["name"]]
    tool_output = selected_tool.invoke(tool_call["args"])
    return ToolMessage(tool_call_id=tool_call["id"], content=tool_output)