ric9176 commited on
Commit
28f3481
Β·
1 Parent(s): cdceb53

complete refactor to use AsyncSqliteSaver

Browse files
Files changed (4) hide show
  1. agent/__init__.py +2 -2
  2. agent/agent.py +0 -50
  3. agent/graph.py +69 -0
  4. app.py +95 -61
agent/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
- from agent.agent import graph, graph_with_memory
2
 
3
- __all__ = ["graph", "graph_with_memory"]
 
1
+ from agent.graph import create_agent_graph, create_agent_graph_without_memory, get_checkpointer
2
 
3
+ __all__ = ["create_agent_graph", "create_agent_graph_without_memory", "get_checkpointer"]
agent/agent.py DELETED
@@ -1,50 +0,0 @@
1
- from langgraph.graph import StateGraph
2
- from langgraph.checkpoint.memory import MemorySaver
3
-
4
- from agent.utils.state import AgentState
5
- from agent.utils.nodes import call_model, tool_node, should_continue
6
-
7
- def create_agent_graph():
8
- # Create the graph
9
- builder = StateGraph(AgentState)
10
-
11
- # Add nodes
12
- builder.add_node("agent", call_model)
13
- builder.add_node("action", tool_node)
14
-
15
- # Update edges
16
- builder.set_entry_point("agent")
17
- builder.add_conditional_edges(
18
- "agent",
19
- should_continue,
20
- )
21
- builder.add_edge("action", "agent")
22
-
23
- # Initialize memory saver for conversation persistence
24
- memory = MemorySaver()
25
-
26
- # Compile the graph with memory
27
- return builder.compile(checkpointer=memory)
28
-
29
- def create_agent_graph_without_memory():
30
- # Create the graph
31
- builder = StateGraph(AgentState)
32
-
33
- # Add nodes
34
- builder.add_node("agent", call_model)
35
- builder.add_node("action", tool_node)
36
-
37
- # Update edges
38
- builder.set_entry_point("agent")
39
- builder.add_conditional_edges(
40
- "agent",
41
- should_continue,
42
- )
43
- builder.add_edge("action", "agent")
44
-
45
- # Compile the graph without memory
46
- return builder.compile()
47
-
48
- # Create both graph variants
49
- graph_with_memory = create_agent_graph()
50
- graph = create_agent_graph_without_memory()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
agent/graph.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langgraph.graph import StateGraph
2
+ from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
3
+ import aiosqlite
4
+ from types import TracebackType
5
+ from typing import Optional, Type
6
+
7
+ from agent.utils.state import AgentState
8
+ from agent.utils.nodes import call_model, tool_node, should_continue
9
+
10
+ def create_graph_builder():
11
+ """Create a base graph builder with nodes and edges configured."""
12
+ # Create the graph
13
+ builder = StateGraph(AgentState)
14
+
15
+ # Add nodes
16
+ builder.add_node("agent", call_model)
17
+ builder.add_node("action", tool_node)
18
+
19
+ # Update edges
20
+ builder.set_entry_point("agent")
21
+ builder.add_conditional_edges(
22
+ "agent",
23
+ should_continue,
24
+ )
25
+ builder.add_edge("action", "agent")
26
+
27
+ return builder
28
+
29
+ def create_agent_graph_without_memory():
30
+ """Create an agent graph without memory persistence."""
31
+ builder = create_graph_builder()
32
+ return builder.compile()
33
+
34
+ class SQLiteCheckpointer:
35
+ """Context manager for SQLite checkpointing."""
36
+
37
+ def __init__(self, db_path: str):
38
+ self.db_path = db_path
39
+ self.saver: Optional[AsyncSqliteSaver] = None
40
+
41
+ async def __aenter__(self) -> AsyncSqliteSaver:
42
+ """Initialize and return the AsyncSqliteSaver."""
43
+ conn = await aiosqlite.connect(self.db_path)
44
+ self.saver = AsyncSqliteSaver(conn)
45
+ return self.saver
46
+
47
+ async def __aexit__(
48
+ self,
49
+ exc_type: Optional[Type[BaseException]],
50
+ exc_val: Optional[BaseException],
51
+ exc_tb: Optional[TracebackType],
52
+ ) -> None:
53
+ """Clean up the SQLite connection."""
54
+ if self.saver and hasattr(self.saver, 'conn'):
55
+ await self.saver.conn.close()
56
+ self.saver = None
57
+
58
+ def get_checkpointer(db_path: str = "data/short_term.db") -> SQLiteCheckpointer:
59
+ """Create and return a SQLiteCheckpointer instance."""
60
+ return SQLiteCheckpointer(db_path)
61
+
62
+ async def create_agent_graph(checkpointer: AsyncSqliteSaver):
63
+ """Create an agent graph with SQLite-based memory persistence."""
64
+ builder = create_graph_builder()
65
+ graph = builder.compile(checkpointer=checkpointer)
66
+ return graph
67
+
68
+ # Export the graph builder functions
69
+ __all__ = ["create_agent_graph", "create_agent_graph_without_memory", "get_checkpointer"]
app.py CHANGED
@@ -2,8 +2,16 @@ import uuid
2
  from langchain_core.messages import HumanMessage, AIMessage
3
  from langchain.schema.runnable.config import RunnableConfig
4
  import chainlit as cl
5
- from agent import graph_with_memory as graph
6
  from agent.utils.state import AgentState
 
 
 
 
 
 
 
 
7
 
8
  @cl.on_chat_start
9
  async def on_chat_start():
@@ -11,8 +19,8 @@ async def on_chat_start():
11
  session_id = str(uuid.uuid4())
12
  cl.user_session.set("session_id", session_id)
13
 
14
- # Initialize the conversation state with proper auth
15
- cl.user_session.set("messages", [])
16
 
17
  # Initialize config using stored session ID
18
  config = RunnableConfig(
@@ -24,10 +32,12 @@ async def on_chat_start():
24
 
25
  # Initialize empty state with auth
26
  try:
27
- await graph.ainvoke(
28
- AgentState(messages=[], context=[]),
29
- config=config
30
- )
 
 
31
  except Exception as e:
32
  print(f"Error initializing state: {str(e)}")
33
 
@@ -38,71 +48,95 @@ async def on_chat_start():
38
 
39
  @cl.on_message
40
  async def on_message(message: cl.Message):
 
41
  session_id = cl.user_session.get("session_id")
42
- print(f"Session ID: {session_id}")
43
  if not session_id:
44
  session_id = str(uuid.uuid4())
45
  cl.user_session.set("session_id", session_id)
46
 
 
 
 
 
 
 
 
 
 
 
47
  config = RunnableConfig(
48
  configurable={
49
  "thread_id": session_id,
50
- "checkpoint_ns": "default_namespace",
51
  "sessionId": session_id
52
  }
53
  )
54
 
55
- # Try to retrieve previous conversation state
56
  try:
57
- previous_state = await graph.aget_state(config)
58
- if previous_state and previous_state.values:
59
- previous_messages = previous_state.values.get('messages', [])
60
- print("Found previous state with messages:", len(previous_messages))
61
- else:
62
- print("Previous state empty or invalid")
63
- previous_messages = []
64
- current_messages = previous_messages + [HumanMessage(content=message.content)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  except Exception as e:
66
- print(f"Error retrieving previous state: {str(e)}")
67
- current_messages = [HumanMessage(content=message.content)]
68
-
69
- # Setup callback handler and final answer message
70
- cb = cl.LangchainCallbackHandler()
71
- final_answer = cl.Message(content="")
72
- await final_answer.send()
73
-
74
- loading_msg = None # Initialize reference to loading message
75
-
76
- # Stream the response
77
- async for chunk in graph.astream(
78
- AgentState(messages=current_messages, context=[]),
79
- config=RunnableConfig(
80
- configurable={
81
- "thread_id": session_id,
82
- }
83
- )
84
- ):
85
- for node, values in chunk.items():
86
- if node == "retrieve":
87
- loading_msg = cl.Message(content="πŸ” Searching knowledge base...", author="System")
88
- await loading_msg.send()
89
- elif values.get("messages"):
90
- last_message = values["messages"][-1]
91
- # Check for tool calls in additional_kwargs
92
- if hasattr(last_message, "additional_kwargs") and last_message.additional_kwargs.get("tool_calls"):
93
- tool_name = last_message.additional_kwargs["tool_calls"][0]["function"]["name"]
94
- if loading_msg:
95
- await loading_msg.remove()
96
- loading_msg = cl.Message(
97
- content=f"πŸ” Using {tool_name}...",
98
- author="Tool"
99
- )
100
- await loading_msg.send()
101
- # Only stream AI messages, skip tool outputs
102
- elif isinstance(last_message, AIMessage):
103
- if loading_msg:
104
- await loading_msg.remove()
105
- loading_msg = None
106
- await final_answer.stream_token(last_message.content)
107
-
108
- await final_answer.send()
 
2
  from langchain_core.messages import HumanMessage, AIMessage
3
  from langchain.schema.runnable.config import RunnableConfig
4
  import chainlit as cl
5
+ from agent import create_agent_graph, create_agent_graph_without_memory, get_checkpointer
6
  from agent.utils.state import AgentState
7
+ import os
8
+ import json
9
+
10
+ # Path to SQLite database for short-term memory
11
+ SHORT_TERM_MEMORY_DB_PATH = "data/short_term.db"
12
+
13
+ # Ensure the data directory exists
14
+ os.makedirs(os.path.dirname(SHORT_TERM_MEMORY_DB_PATH), exist_ok=True)
15
 
16
  @cl.on_chat_start
17
  async def on_chat_start():
 
19
  session_id = str(uuid.uuid4())
20
  cl.user_session.set("session_id", session_id)
21
 
22
+ # Initialize empty message history
23
+ cl.user_session.set("message_history", [])
24
 
25
  # Initialize config using stored session ID
26
  config = RunnableConfig(
 
32
 
33
  # Initialize empty state with auth
34
  try:
35
+ async with get_checkpointer(SHORT_TERM_MEMORY_DB_PATH) as saver:
36
+ graph = await create_agent_graph(saver)
37
+ initial_state = AgentState(messages=[], context=[])
38
+ await graph.ainvoke(initial_state, config=config)
39
+ # Store initial state as a serializable dict
40
+ cl.user_session.set("last_state", {"messages": [], "context": []})
41
  except Exception as e:
42
  print(f"Error initializing state: {str(e)}")
43
 
 
48
 
49
  @cl.on_message
50
  async def on_message(message: cl.Message):
51
+ # Get or create session ID
52
  session_id = cl.user_session.get("session_id")
 
53
  if not session_id:
54
  session_id = str(uuid.uuid4())
55
  cl.user_session.set("session_id", session_id)
56
 
57
+ print(f"Session ID: {session_id}")
58
+
59
+ # Get message history
60
+ message_history = cl.user_session.get("message_history", [])
61
+
62
+ # Add new message to history
63
+ current_message = HumanMessage(content=message.content)
64
+ message_history.append(current_message)
65
+ cl.user_session.set("message_history", message_history)
66
+
67
  config = RunnableConfig(
68
  configurable={
69
  "thread_id": session_id,
70
+ "checkpoint_ns": session_id, # Use session_id as namespace
71
  "sessionId": session_id
72
  }
73
  )
74
 
 
75
  try:
76
+ async with get_checkpointer(SHORT_TERM_MEMORY_DB_PATH) as saver:
77
+ # Create graph with SQLite memory
78
+ graph = await create_agent_graph(saver)
79
+
80
+ # Get the last state or create new one
81
+ last_state_dict = cl.user_session.get("last_state", {"messages": [], "context": []})
82
+
83
+ # Create new state with current message history
84
+ current_state = AgentState(
85
+ messages=message_history,
86
+ context=last_state_dict.get("context", [])
87
+ )
88
+
89
+ # Setup callback handler and final answer message
90
+ cb = cl.LangchainCallbackHandler()
91
+ final_answer = cl.Message(content="")
92
+ await final_answer.send()
93
+
94
+ loading_msg = None # Initialize reference to loading message
95
+ last_state = None # Track the final state
96
+
97
+ # Stream the response
98
+ async for chunk in graph.astream(
99
+ current_state,
100
+ config=config
101
+ ):
102
+ for node, values in chunk.items():
103
+ if node == "retrieve":
104
+ if loading_msg:
105
+ await loading_msg.remove()
106
+ loading_msg = cl.Message(content="πŸ” Searching knowledge base...", author="System")
107
+ await loading_msg.send()
108
+ elif values.get("messages"):
109
+ last_message = values["messages"][-1]
110
+ # Check for tool calls in additional_kwargs
111
+ if hasattr(last_message, "additional_kwargs") and last_message.additional_kwargs.get("tool_calls"):
112
+ tool_name = last_message.additional_kwargs["tool_calls"][0]["function"]["name"]
113
+ if loading_msg:
114
+ await loading_msg.remove()
115
+ loading_msg = cl.Message(
116
+ content=f"πŸ” Using {tool_name}...",
117
+ author="Tool"
118
+ )
119
+ await loading_msg.send()
120
+ # Only stream AI messages, skip tool outputs
121
+ elif isinstance(last_message, AIMessage):
122
+ if loading_msg:
123
+ await loading_msg.remove()
124
+ loading_msg = None
125
+ await final_answer.stream_token(last_message.content)
126
+ # Add AI message to history
127
+ message_history.append(last_message)
128
+ cl.user_session.set("message_history", message_history)
129
+ # Update last state
130
+ last_state = values
131
+
132
+ # Update the last state as a serializable dict
133
+ if last_state:
134
+ cl.user_session.set("last_state", {
135
+ "messages": [msg.content for msg in message_history],
136
+ "context": last_state.get("context", [])
137
+ })
138
+ await final_answer.send()
139
+
140
  except Exception as e:
141
+ print(f"Error in message handler: {str(e)}")
142
+ await cl.Message(content="I apologize, but I encountered an error processing your message. Please try again.").send()