Aasher commited on
Commit
dfd09d4
·
1 Parent(s): 38ccc4e

feat(memories): add get_recent_memories task and integrate recent memories into agent workflow

Browse files
Files changed (3) hide show
  1. workflow/agent.py +6 -3
  2. workflow/prompts.py +6 -3
  3. workflow/tasks.py +19 -0
workflow/agent.py CHANGED
@@ -3,7 +3,7 @@ from langgraph.func import entrypoint
3
  from langchain_core.messages import BaseMessage
4
  from langchain_core.runnables import RunnableConfig
5
 
6
- from .tasks import call_model, get_structued_output, manage_memories, call_tool
7
 
8
 
9
  @entrypoint()
@@ -12,7 +12,10 @@ def agent(messages: list[BaseMessage], config: RunnableConfig):
12
  answer = None
13
  links = []
14
 
15
- llm_future = call_model(messages)
 
 
 
16
  memories_future = manage_memories(messages[-1].content)
17
 
18
  # Now, wait for both to complete
@@ -34,7 +37,7 @@ def agent(messages: list[BaseMessage], config: RunnableConfig):
34
  messages = add_messages(messages, [llm_response, *tool_results])
35
 
36
  # Call model again
37
- llm_response = call_model(messages).result()
38
 
39
  if tool_calls_count > 0:
40
  # Structure the final output
 
3
  from langchain_core.messages import BaseMessage
4
  from langchain_core.runnables import RunnableConfig
5
 
6
+ from .tasks import call_model, get_structued_output, manage_memories, call_tool, get_recent_memories
7
 
8
 
9
  @entrypoint()
 
12
  answer = None
13
  links = []
14
 
15
+ # Fetch Recent User Information
16
+ recent_memories = get_recent_memories().result()
17
+
18
+ llm_future = call_model(messages, memories=recent_memories)
19
  memories_future = manage_memories(messages[-1].content)
20
 
21
  # Now, wait for both to complete
 
37
  messages = add_messages(messages, [llm_response, *tool_results])
38
 
39
  # Call model again
40
+ llm_response = call_model(messages, memories=recent_memories).result()
41
 
42
  if tool_calls_count > 0:
43
  # Structure the final output
workflow/prompts.py CHANGED
@@ -37,9 +37,12 @@ If you used `search_vectorstore` tool to retrieve documents, you should:
37
  **Note:** Do not provide redundant sources.
38
 
39
  ---
40
- ## Memories
41
- - When needed, use the `search_memories` tool to access information about the user's history and preferences
42
- - This tool provides context about past interactions and helps personalize responses
 
 
 
43
 
44
  ---
45
  ## Privacy
 
37
  **Note:** Do not provide redundant sources.
38
 
39
  ---
40
+ ## User Context
41
+ 1. You can use `search_memories` tool to retrieve stored user-related information such as personal details, hobbies, preferences, and background. This information evolves as the user continues to interact with you. Use it to personalize your responses.
42
+ 2. Current known information about the user:
43
+ {memories}
44
+
45
+ **Note:** DO NOT call `search_memories` tool if the section above indicates that no information is available yet. This usually means the user is new or hasn’t shared any information yet.
46
 
47
  ---
48
  ## Privacy
workflow/tasks.py CHANGED
@@ -48,6 +48,25 @@ def manage_memories(user_message: str, config: RunnableConfig):
48
  )
49
  return memories
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  @task
52
  def call_tool(tool_call, config: RunnableConfig):
53
  tool = tools_by_name[tool_call["name"]]
 
48
  )
49
  return memories
50
 
51
+ @task
52
+ def get_recent_memories(config: RunnableConfig):
53
+ """Retrieve the most recent user memories (max 10) based on `updated_at` timestamp. """
54
+ user_id = config.get("configurable", {}).get("user_id")
55
+
56
+ if not user_id:
57
+ raise ValueError("User Id not found in config")
58
+
59
+ memories = memory_client.get_all(version="v2", filters={"user_id": user_id})
60
+
61
+ if not memories:
62
+ return "(No information is available about the user yet)"
63
+
64
+ # Sort memories by `updated_at` in descending order
65
+ sorted_memories = sorted(memories, key=lambda m: m["updated_at"], reverse=True)
66
+ recent_memories = sorted_memories[:10]
67
+
68
+ return "\n".join(f"- {mem['memory']}" for mem in recent_memories)
69
+
70
  @task
71
  def call_tool(tool_call, config: RunnableConfig):
72
  tool = tools_by_name[tool_call["name"]]