ric9176 committed on
Commit 241f177 · 1 Parent(s): 5e09d50

Basic working implementation of cross-thread memory

.langgraph_api/.langgraph_checkpoint.1.pckl CHANGED
Binary files a/.langgraph_api/.langgraph_checkpoint.1.pckl and b/.langgraph_api/.langgraph_checkpoint.1.pckl differ
 
.langgraph_api/.langgraph_checkpoint.2.pckl CHANGED
Binary files a/.langgraph_api/.langgraph_checkpoint.2.pckl and b/.langgraph_api/.langgraph_checkpoint.2.pckl differ
 
.langgraph_api/.langgraph_ops.pckl CHANGED
Binary files a/.langgraph_api/.langgraph_ops.pckl and b/.langgraph_api/.langgraph_ops.pckl differ
 
.langgraph_api/.langgraph_retry_counter.pckl CHANGED
Binary files a/.langgraph_api/.langgraph_retry_counter.pckl and b/.langgraph_api/.langgraph_retry_counter.pckl differ
 
.langgraph_api/store.pckl CHANGED
Binary files a/.langgraph_api/store.pckl and b/.langgraph_api/store.pckl differ
 
.langgraph_api/store.vectors.pckl CHANGED
Binary files a/.langgraph_api/store.vectors.pckl and b/.langgraph_api/store.vectors.pckl differ
 
agent/__init__.py CHANGED
@@ -1,3 +1,3 @@
-from agent.graph import create_agent_graph, create_agent_graph_without_memory, get_checkpointer
+from agent.graph import create_agent_graph, create_agent_graph_without_memory, get_checkpointer, langgraph_studio_graph
 
-__all__ = ["create_agent_graph", "create_agent_graph_without_memory", "get_checkpointer"]
+__all__ = ["create_agent_graph", "create_agent_graph_without_memory", "get_checkpointer", "langgraph_studio_graph"]
agent/graph.py CHANGED
@@ -1,29 +1,54 @@
-from langgraph.graph import StateGraph
+from langgraph.graph import StateGraph, END
 from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
+from langgraph.store.memory import InMemoryStore
 import aiosqlite
 from types import TracebackType
 from typing import Optional, Type
 
 from agent.utils.state import AgentState
-from agent.utils.nodes import call_model, tool_node, should_continue
+from agent.utils.nodes import (
+    call_model,
+    tool_node,
+    read_memory,
+    write_memory,
+    should_continue
+)
 
 def create_graph_builder():
     """Create a base graph builder with nodes and edges configured."""
-    # Create the graph
     builder = StateGraph(AgentState)
 
     # Add nodes
     builder.add_node("agent", call_model)
     builder.add_node("action", tool_node)
+    builder.add_node("read_memory", read_memory)
+    builder.add_node("write_memory", write_memory)
 
-    # Update edges
+    # Set entry point
    builder.set_entry_point("agent")
+
+    builder.add_edge("agent", "write_memory")
+    builder.add_edge("write_memory", END)
+
+    # Add conditional edges from agent
     builder.add_conditional_edges(
         "agent",
         should_continue,
+        {
+            "action": "action",
+            "read_memory": "read_memory",
+            "write_memory": "write_memory",
+            END: END
+        }
     )
+
+    # Connect action back to agent
     builder.add_edge("action", "agent")
 
+    # Route read_memory back to the agent
+    builder.add_edge("read_memory", "agent")
+
     return builder
 
 def create_agent_graph_without_memory():
@@ -59,11 +84,21 @@ def get_checkpointer(db_path: str = "data/short_term.db") -> SQLiteCheckpointer:
     """Create and return a SQLiteCheckpointer instance."""
     return SQLiteCheckpointer(db_path)
 
+# Initialize store for across-thread memory
+across_thread_memory = InMemoryStore()
+
 async def create_agent_graph(checkpointer: AsyncSqliteSaver):
-    """Create an agent graph with SQLite-based memory persistence."""
+    """Create an agent graph with memory persistence."""
     builder = create_graph_builder()
-    graph = builder.compile(checkpointer=checkpointer)
+    # Compile with both the SQLite checkpointer for within-thread memory
+    # and the InMemoryStore for across-thread memory
+    graph = builder.compile(
+        checkpointer=checkpointer,
+        store=across_thread_memory
+    )
     return graph
 
+langgraph_studio_graph = create_agent_graph_without_memory()
+
 # Export the graph builder functions
 __all__ = ["create_agent_graph", "create_agent_graph_without_memory", "get_checkpointer"]
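
The compiled graph now combines two memory scopes: the AsyncSqliteSaver checkpoints conversation state per thread_id (within-thread memory), while the shared InMemoryStore holds the user profile under a ("memory", session_id) namespace (across-thread memory). A minimal driver sketch of how that plays out (not part of this commit; the IDs and message text are illustrative):

    import asyncio
    from langchain_core.messages import HumanMessage
    from agent.graph import create_agent_graph, get_checkpointer

    async def main():
        async with get_checkpointer("data/short_term.db") as saver:
            graph = await create_agent_graph(saver)

            # Thread 1: the user states a preference; write_memory stores it
            # under ("memory", "user-42") in the shared InMemoryStore.
            cfg1 = {"configurable": {"thread_id": "thread-1", "session_id": "user-42"}}
            await graph.ainvoke(
                {"messages": [HumanMessage(content="I really love sailing")], "context": []},
                config=cfg1,
            )

            # Thread 2: a fresh checkpoint, but the same session_id, so
            # call_model still sees the sailing preference via the store.
            cfg2 = {"configurable": {"thread_id": "thread-2", "session_id": "user-42"}}
            result = await graph.ainvoke(
                {"messages": [HumanMessage(content="What should I do this weekend?")], "context": []},
                config=cfg2,
            )
            print(result["messages"][-1].content)

    asyncio.run(main())
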
agent/utils/nodes.py CHANGED
@@ -1,21 +1,30 @@
 from langchain_openai import ChatOpenAI
-from langchain_core.messages import SystemMessage
+from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
 from langgraph.graph import END
 from langgraph.prebuilt import ToolNode
+from langchain.memory import ConversationBufferMemory
+from langchain_core.runnables.config import RunnableConfig
+from langgraph.store.base import BaseStore
+from typing import Literal
+# from chainlit.logger import logger
+
 
 from agent.utils.tools import tool_belt
 from agent.utils.state import AgentState
 
-# Initialize LLM
-llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
-model = llm.bind_tools(tool_belt)
+# Initialize LLM for memory operations
+model = ChatOpenAI(model="gpt-4", temperature=0)
 
-# Define system prompt
-SYSTEM_PROMPT = SystemMessage(content="""You are a Chief Joy Officer, an AI assistant focused on helping people find fun and enriching activities in London.
+# Define system prompt with memory
+SYSTEM_PROMPT = """You are a Chief Joy Officer, an AI assistant focused on helping people find fun and enriching activities in London.
+You have access to memory about the user's preferences and past interactions.
+
+Here is what you remember about this user:
+{memory}
 
 Your core objectives are to:
 1. Understand and remember user preferences and interests
-2. Provide personalized activity recommendations
+2. Provide personalized activity recommendations based on their interests
 3. Be engaging and enthusiastic while maintaining professionalism
 4. Give clear, actionable suggestions
 
@@ -23,35 +32,210 @@ Key tools at your disposal:
 - retrieve_context: For finding specific information about events and activities
 - tavily_search: For general web searches about London activities
 
-Always aim to provide value while being mindful of the user's time and interests.""")
+Always aim to provide value while being mindful of the user's time and interests."""
 
-# Define memory prompt
-MEMORY_PROMPT = """Here is the conversation history and relevant information about the user:
+# Define memory creation/update prompt
+MEMORY_UPDATE_PROMPT = """You are analyzing the conversation to update the user's profile and preferences.
 
+CURRENT USER INFORMATION:
 {memory}
 
-Please use this context to provide more personalized responses. When appropriate, reference past interactions and demonstrated preferences to make your suggestions more relevant.
-
-Remember to:
-1. Acknowledge previously mentioned interests
-2. Build upon past recommendations
-3. Avoid repeating suggestions already discussed
-4. Note any changes in preferences
-
-Current conversation:
-{conversation}"""
+INSTRUCTIONS:
+1. Review the chat history carefully
+2. Identify new information about the user, such as:
+   - Activity preferences (indoor/outdoor, cultural/sports, etc.)
+   - Specific interests (art, music, food, etc.)
+   - Location preferences in London
+   - Time/schedule constraints
+   - Past experiences with activities
+   - Budget considerations
+3. Merge new information with existing memory
+4. Format as a clear, bulleted list
+5. If new information conflicts with existing memory, keep the most recent
+
+Remember: Only include factual information directly stated by the user. Do not make assumptions.
+
+Based on the conversation, please update the user information:"""
 
-def call_model(state: AgentState):
-    messages = [SYSTEM_PROMPT] + state["messages"]
+def get_last_human_message(state: AgentState):
+    """Get the last human message from the state."""
+    for message in reversed(state["messages"]):
+        if isinstance(message, HumanMessage):
+            return message
+    return None
+
+def call_model(state: AgentState, config: RunnableConfig, store: BaseStore):
+    """Process messages using memory from the store."""
+    # Get the user ID from the config
+    user_id = config["configurable"].get("session_id", "default")
+
+    # Retrieve memory from the store
+    namespace = ("memory", user_id)
+    existing_memory = store.get(namespace, "user_memory")
+
+    # Extract memory content or use default
+    memory_content = existing_memory.value.get('memory') if existing_memory else "No previous information about this user."
+
+    # Create messages list with system prompt including memory
+    messages = [
+        SystemMessage(content=SYSTEM_PROMPT.format(memory=memory_content))
+    ] + state["messages"]
+
     response = model.invoke(messages)
     return {"messages": [response]}
 
+def update_memory(state: AgentState, config: RunnableConfig, store: BaseStore):
+    """Update user memory based on conversation."""
+    user_id = config["configurable"].get("session_id", "default")
+    namespace = ("memory", user_id)
+    existing_memory = store.get(namespace, "user_memory")
+
+    memory_content = existing_memory.value.get('memory') if existing_memory else "No previous information about this user."
+
+    update_prompt = MEMORY_UPDATE_PROMPT.format(memory=memory_content)
+    new_memory = model.invoke([
+        SystemMessage(content=update_prompt)
+    ] + state["messages"])
+
+    store.put(namespace, "user_memory", {"memory": new_memory.content})
+    return state
+
+def should_continue(state: AgentState) -> Literal["action", "read_memory", "write_memory", END]:
+    """Determine the next node in the graph."""
+    if not state["messages"]:
+        return END
+
+    last_message = state["messages"][-1]
+    if isinstance(last_message, list):
+        last_message = last_message[-1]
+
+    last_human_message = get_last_human_message(state)
+
+    # Handle tool calls
+    if hasattr(last_message, "additional_kwargs") and last_message.additional_kwargs.get("tool_calls"):
+        return "action"
+
+    # Handle memory operations for human messages
+    if last_human_message:
+        # Write memory for longer messages that might contain personal information
+        if len(last_human_message.content.split()) > 3:
+            return "write_memory"
+        # Read memory for short queries to ensure personalized responses
+        else:
+            return "read_memory"
+
+    return END
+
+def read_memory(state: AgentState, config: RunnableConfig, store: BaseStore):
+    """Read and apply memory context without updating it."""
+    user_id = config["configurable"].get("session_id", "default")
+    namespace = ("memory", user_id)
+    existing_memory = store.get(namespace, "user_memory")
+
+    if existing_memory:
+        memory_content = existing_memory.value.get('memory')
+        # Add memory context to state for next model call
+        state["memory_context"] = memory_content
+
+    return state
+
+# Define the memory creation prompt
+MEMORY_CREATION_PROMPT = """You are collecting information about the user to personalize your responses.
+
+CURRENT USER INFORMATION:
+{memory}
+
+INSTRUCTIONS:
+1. Review the chat history below carefully
+2. Identify new information about the user, such as:
+   - Personal details (name, location)
+   - Preferences (likes, dislikes)
+   - Interests and hobbies
+   - Past experiences
+   - Goals or future plans
+3. Merge any new information with existing memory
+4. Format the memory as a clear, bulleted list
+5. If new information conflicts with existing memory, keep the most recent version
+
+Remember: Only include factual information directly stated by the user. Do not make assumptions or inferences.
+
+Based on the chat history below, please update the user information:"""
+
+async def write_memory(state: AgentState, config: RunnableConfig, store: BaseStore) -> AgentState:
+    """Reflect on the chat history and save a memory to the store."""
+
+    # Get the session ID from config
+    session_id = config["configurable"].get("session_id", "default")
+
+    # Define the namespace for this user's memory
+    namespace = ("memory", session_id)
+
+    # Get existing memory using async interface
+    existing_memory = await store.aget(namespace, "user_memory")
+    memory_content = existing_memory.value.get('memory') if existing_memory else "No previous information about this user."
+
+    # Create system message with memory context
+    system_msg = SystemMessage(content=MEMORY_CREATION_PROMPT.format(memory=memory_content))
+
+    # Get messages and ensure we're working with the correct format
+    messages = state.get("messages", [])
+    if not messages:
+        return state
+
+    # Create memory using the model
+    new_memory = await model.ainvoke([system_msg] + messages)
+
+    # Store the updated memory using async interface
+    await store.aput(namespace, "user_memory", {"memory": new_memory.content})
+
+    return state
+
 # Initialize tool node
 tool_node = ToolNode(tool_belt)
 
-# Simple flow control - always go to final
-def should_continue(state):
+def should_call_memory(state: AgentState) -> Literal["update_memory", "end"]:
+    """
+    Determine if we should update memory based on the conversation state.
+
+    Rules for updating memory:
+    1. Only update after human messages (not tool responses)
+    2. Update if the message might contain personal information
+    3. Don't update for simple queries or acknowledgments
+    """
+    if not state["messages"]:
+        return "end"
+
     last_message = state["messages"][-1]
-    if last_message.tool_calls:
-        return "action"
-    return END
+
+    # Skip memory update for tool calls
+    if hasattr(last_message, "additional_kwargs") and last_message.additional_kwargs.get("tool_calls"):
+        return "agent"
+
+    # Skip memory update for very short messages (likely acknowledgments)
+    if isinstance(last_message, HumanMessage) and len(last_message.content.split()) <= 3:
+        return "agent"
+
+    # Update memory for human messages that might contain personal information
+    if isinstance(last_message, HumanMessage):
+        return "update_memory"
+
+    return "agent"
+
+# def route_message(state: MessagesState, config: RunnableConfig, store: BaseStore) -> Literal[END, "update_todos", "update_instructions", "update_profile"]:
+#     """Reflect on the memories and chat history to decide whether to update the memory collection."""
+#     message = state['messages'][-1]
+#     if len(message.tool_calls) == 0:
+#         return END
+#     else:
+#         tool_call = message.tool_calls[0]
+#         if tool_call['args']['update_type'] == "user":
+#             return "update_profile"
+#         elif tool_call['args']['update_type'] == "todo":
+#             return "update_todos"
+#         elif tool_call['args']['update_type'] == "instructions":
#             return "update_instructions"
+#         else:
+#             raise ValueError
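
All of the memory nodes above rely on the same BaseStore contract: values live under a (namespace, key) pair, put/aput store a plain dict, and get/aget return an item whose .value is that dict, or None if nothing has been written. A standalone sketch of that contract with illustrative values (not part of the commit):

    from langgraph.store.memory import InMemoryStore

    store = InMemoryStore()
    namespace = ("memory", "user-42")  # same shape as ("memory", session_id) above

    # write_memory persists the distilled profile as a plain dict
    store.put(namespace, "user_memory", {"memory": "- Enjoys sailing\n- Based in London"})

    # call_model / read_memory fetch it back; .value is the stored dict
    item = store.get(namespace, "user_memory")
    if item is not None:
        print(item.value["memory"])
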
agent/utils/state.py CHANGED
@@ -3,4 +3,4 @@ from langgraph.graph.message import add_messages
 
 class AgentState(TypedDict):
     messages: Annotated[list, add_messages]
-    context: list # Store retrieved context
+    context: list # Store retrieved context
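
Note that read_memory in nodes.py sets a memory_context key that this TypedDict does not declare. A sketch of how the state could declare that field explicitly (an assumption, not part of the commit):

    from typing import Annotated
    from typing_extensions import TypedDict, NotRequired
    from langgraph.graph.message import add_messages

    class AgentState(TypedDict):
        messages: Annotated[list, add_messages]
        context: list  # Store retrieved context
        memory_context: NotRequired[str]  # hypothetical field for read_memory's output
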
app.py CHANGED
@@ -26,7 +26,8 @@ async def on_chat_start():
     config = RunnableConfig(
         configurable={
             "thread_id": session_id,
-            "sessionId": session_id
+            "session_id": session_id,
+            "checkpoint_ns": session_id
         }
     )
 
@@ -34,10 +35,18 @@ async def on_chat_start():
     try:
         async with get_checkpointer(SHORT_TERM_MEMORY_DB_PATH) as saver:
             graph = await create_agent_graph(saver)
-            initial_state = AgentState(messages=[], context=[])
+            initial_state = AgentState(
+                messages=[],
+                context=[]
+            )
+
             await graph.ainvoke(initial_state, config=config)
-            # Store initial state as a serializable dict
-            cl.user_session.set("last_state", {"messages": [], "context": []})
+
+            # Store initial state
+            cl.user_session.set("last_state", {
+                "messages": [],
+                "context": []
+            })
     except Exception as e:
         print(f"Error initializing state: {str(e)}")
 
@@ -67,14 +76,14 @@ async def on_message(message: cl.Message):
     config = RunnableConfig(
         configurable={
             "thread_id": session_id,
-            "checkpoint_ns": session_id, # Use session_id as namespace
-            "sessionId": session_id
+            "session_id": session_id,
+            "checkpoint_ns": session_id
         }
     )
 
     try:
         async with get_checkpointer(SHORT_TERM_MEMORY_DB_PATH) as saver:
-            # Create graph with SQLite memory
+            # Create graph with memory
             graph = await create_agent_graph(saver)
 
             # Get the last state or create new one
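
The switch from sessionId to session_id matters because the nodes read the key back verbatim from the RunnableConfig; nothing maps between spellings. A minimal illustration (the function name is illustrative, not part of the commit):

    from langchain_core.runnables.config import RunnableConfig

    def read_user_id(config: RunnableConfig) -> str:
        # Exactly what call_model/write_memory do: a camelCase "sessionId"
        # entry would fall through to the "default" fallback.
        return config["configurable"].get("session_id", "default")

    config = RunnableConfig(configurable={"thread_id": "abc", "session_id": "abc"})
    assert read_user_id(config) == "abc"
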
data/short_term.db-shm ADDED
Binary file (32.8 kB).
 
example.json ADDED
@@ -0,0 +1,4 @@
+{
+    "role": "user",
+    "content": "I like sailing a lot, tell me about some activities I can do and remember this fact about me"
+}
langgraph.json CHANGED
@@ -1,7 +1,7 @@
 {
     "dependencies": ".",
     "graphs": {
-        "agent": "agent:graph"
+        "agent": "agent:langgraph_studio_graph"
     },
     "env": ".env"
 }
tools.py DELETED
@@ -1,28 +0,0 @@
-from langchain_core.tools import tool
-from langchain_community.tools.tavily_search import TavilySearchResults
-from rag import create_rag_pipeline, add_urls_to_vectorstore
-
-# Initialize RAG pipeline
-rag_components = create_rag_pipeline(collection_name="london_events")
-
-# Add some initial URLs to the vector store
-urls = [
-    "https://www.timeout.com/london/things-to-do-in-london-this-weekend",
-    "https://www.timeout.com/london/london-events-in-march"
-]
-add_urls_to_vectorstore(
-    rag_components["vector_store"],
-    rag_components["text_splitter"],
-    urls
-)
-
-@tool
-def retrieve_context(query: str) -> list[str]:
-    """Searches the knowledge base for relevant information about events and activities. Use this when you need specific details about events."""
-    return [doc.page_content for doc in rag_components["retriever"].get_relevant_documents(query)]
-
-# Initialize Tavily search tool
-tavily_tool = TavilySearchResults(max_results=5)
-
-# Create tool belt
-tool_belt = [tavily_tool, retrieve_context]