Spaces:
Sleeping
Sleeping
complete refactor to use AsyncSqliteSaver
Browse files- agent/__init__.py +2 -2
- agent/agent.py +0 -50
- agent/graph.py +69 -0
- app.py +95 -61
agent/__init__.py
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
-
from agent.
|
2 |
|
3 |
-
__all__ = ["
|
|
|
1 |
+
from agent.graph import create_agent_graph, create_agent_graph_without_memory, get_checkpointer
|
2 |
|
3 |
+
__all__ = ["create_agent_graph", "create_agent_graph_without_memory", "get_checkpointer"]
|
agent/agent.py
DELETED
@@ -1,50 +0,0 @@
|
|
1 |
-
from langgraph.graph import StateGraph
|
2 |
-
from langgraph.checkpoint.memory import MemorySaver
|
3 |
-
|
4 |
-
from agent.utils.state import AgentState
|
5 |
-
from agent.utils.nodes import call_model, tool_node, should_continue
|
6 |
-
|
7 |
-
def create_agent_graph():
|
8 |
-
# Create the graph
|
9 |
-
builder = StateGraph(AgentState)
|
10 |
-
|
11 |
-
# Add nodes
|
12 |
-
builder.add_node("agent", call_model)
|
13 |
-
builder.add_node("action", tool_node)
|
14 |
-
|
15 |
-
# Update edges
|
16 |
-
builder.set_entry_point("agent")
|
17 |
-
builder.add_conditional_edges(
|
18 |
-
"agent",
|
19 |
-
should_continue,
|
20 |
-
)
|
21 |
-
builder.add_edge("action", "agent")
|
22 |
-
|
23 |
-
# Initialize memory saver for conversation persistence
|
24 |
-
memory = MemorySaver()
|
25 |
-
|
26 |
-
# Compile the graph with memory
|
27 |
-
return builder.compile(checkpointer=memory)
|
28 |
-
|
29 |
-
def create_agent_graph_without_memory():
|
30 |
-
# Create the graph
|
31 |
-
builder = StateGraph(AgentState)
|
32 |
-
|
33 |
-
# Add nodes
|
34 |
-
builder.add_node("agent", call_model)
|
35 |
-
builder.add_node("action", tool_node)
|
36 |
-
|
37 |
-
# Update edges
|
38 |
-
builder.set_entry_point("agent")
|
39 |
-
builder.add_conditional_edges(
|
40 |
-
"agent",
|
41 |
-
should_continue,
|
42 |
-
)
|
43 |
-
builder.add_edge("action", "agent")
|
44 |
-
|
45 |
-
# Compile the graph without memory
|
46 |
-
return builder.compile()
|
47 |
-
|
48 |
-
# Create both graph variants
|
49 |
-
graph_with_memory = create_agent_graph()
|
50 |
-
graph = create_agent_graph_without_memory()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
agent/graph.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langgraph.graph import StateGraph
|
2 |
+
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
3 |
+
import aiosqlite
|
4 |
+
from types import TracebackType
|
5 |
+
from typing import Optional, Type
|
6 |
+
|
7 |
+
from agent.utils.state import AgentState
|
8 |
+
from agent.utils.nodes import call_model, tool_node, should_continue
|
9 |
+
|
10 |
+
def create_graph_builder():
|
11 |
+
"""Create a base graph builder with nodes and edges configured."""
|
12 |
+
# Create the graph
|
13 |
+
builder = StateGraph(AgentState)
|
14 |
+
|
15 |
+
# Add nodes
|
16 |
+
builder.add_node("agent", call_model)
|
17 |
+
builder.add_node("action", tool_node)
|
18 |
+
|
19 |
+
# Update edges
|
20 |
+
builder.set_entry_point("agent")
|
21 |
+
builder.add_conditional_edges(
|
22 |
+
"agent",
|
23 |
+
should_continue,
|
24 |
+
)
|
25 |
+
builder.add_edge("action", "agent")
|
26 |
+
|
27 |
+
return builder
|
28 |
+
|
29 |
+
def create_agent_graph_without_memory():
|
30 |
+
"""Create an agent graph without memory persistence."""
|
31 |
+
builder = create_graph_builder()
|
32 |
+
return builder.compile()
|
33 |
+
|
34 |
+
class SQLiteCheckpointer:
|
35 |
+
"""Context manager for SQLite checkpointing."""
|
36 |
+
|
37 |
+
def __init__(self, db_path: str):
|
38 |
+
self.db_path = db_path
|
39 |
+
self.saver: Optional[AsyncSqliteSaver] = None
|
40 |
+
|
41 |
+
async def __aenter__(self) -> AsyncSqliteSaver:
|
42 |
+
"""Initialize and return the AsyncSqliteSaver."""
|
43 |
+
conn = await aiosqlite.connect(self.db_path)
|
44 |
+
self.saver = AsyncSqliteSaver(conn)
|
45 |
+
return self.saver
|
46 |
+
|
47 |
+
async def __aexit__(
|
48 |
+
self,
|
49 |
+
exc_type: Optional[Type[BaseException]],
|
50 |
+
exc_val: Optional[BaseException],
|
51 |
+
exc_tb: Optional[TracebackType],
|
52 |
+
) -> None:
|
53 |
+
"""Clean up the SQLite connection."""
|
54 |
+
if self.saver and hasattr(self.saver, 'conn'):
|
55 |
+
await self.saver.conn.close()
|
56 |
+
self.saver = None
|
57 |
+
|
58 |
+
def get_checkpointer(db_path: str = "data/short_term.db") -> SQLiteCheckpointer:
|
59 |
+
"""Create and return a SQLiteCheckpointer instance."""
|
60 |
+
return SQLiteCheckpointer(db_path)
|
61 |
+
|
62 |
+
async def create_agent_graph(checkpointer: AsyncSqliteSaver):
|
63 |
+
"""Create an agent graph with SQLite-based memory persistence."""
|
64 |
+
builder = create_graph_builder()
|
65 |
+
graph = builder.compile(checkpointer=checkpointer)
|
66 |
+
return graph
|
67 |
+
|
68 |
+
# Export the graph builder functions
|
69 |
+
__all__ = ["create_agent_graph", "create_agent_graph_without_memory", "get_checkpointer"]
|
app.py
CHANGED
@@ -2,8 +2,16 @@ import uuid
|
|
2 |
from langchain_core.messages import HumanMessage, AIMessage
|
3 |
from langchain.schema.runnable.config import RunnableConfig
|
4 |
import chainlit as cl
|
5 |
-
from agent import
|
6 |
from agent.utils.state import AgentState
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
@cl.on_chat_start
|
9 |
async def on_chat_start():
|
@@ -11,8 +19,8 @@ async def on_chat_start():
|
|
11 |
session_id = str(uuid.uuid4())
|
12 |
cl.user_session.set("session_id", session_id)
|
13 |
|
14 |
-
# Initialize
|
15 |
-
cl.user_session.set("
|
16 |
|
17 |
# Initialize config using stored session ID
|
18 |
config = RunnableConfig(
|
@@ -24,10 +32,12 @@ async def on_chat_start():
|
|
24 |
|
25 |
# Initialize empty state with auth
|
26 |
try:
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
|
|
|
|
31 |
except Exception as e:
|
32 |
print(f"Error initializing state: {str(e)}")
|
33 |
|
@@ -38,71 +48,95 @@ async def on_chat_start():
|
|
38 |
|
39 |
@cl.on_message
|
40 |
async def on_message(message: cl.Message):
|
|
|
41 |
session_id = cl.user_session.get("session_id")
|
42 |
-
print(f"Session ID: {session_id}")
|
43 |
if not session_id:
|
44 |
session_id = str(uuid.uuid4())
|
45 |
cl.user_session.set("session_id", session_id)
|
46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
config = RunnableConfig(
|
48 |
configurable={
|
49 |
"thread_id": session_id,
|
50 |
-
"checkpoint_ns":
|
51 |
"sessionId": session_id
|
52 |
}
|
53 |
)
|
54 |
|
55 |
-
# Try to retrieve previous conversation state
|
56 |
try:
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
except Exception as e:
|
66 |
-
print(f"Error
|
67 |
-
|
68 |
-
|
69 |
-
# Setup callback handler and final answer message
|
70 |
-
cb = cl.LangchainCallbackHandler()
|
71 |
-
final_answer = cl.Message(content="")
|
72 |
-
await final_answer.send()
|
73 |
-
|
74 |
-
loading_msg = None # Initialize reference to loading message
|
75 |
-
|
76 |
-
# Stream the response
|
77 |
-
async for chunk in graph.astream(
|
78 |
-
AgentState(messages=current_messages, context=[]),
|
79 |
-
config=RunnableConfig(
|
80 |
-
configurable={
|
81 |
-
"thread_id": session_id,
|
82 |
-
}
|
83 |
-
)
|
84 |
-
):
|
85 |
-
for node, values in chunk.items():
|
86 |
-
if node == "retrieve":
|
87 |
-
loading_msg = cl.Message(content="π Searching knowledge base...", author="System")
|
88 |
-
await loading_msg.send()
|
89 |
-
elif values.get("messages"):
|
90 |
-
last_message = values["messages"][-1]
|
91 |
-
# Check for tool calls in additional_kwargs
|
92 |
-
if hasattr(last_message, "additional_kwargs") and last_message.additional_kwargs.get("tool_calls"):
|
93 |
-
tool_name = last_message.additional_kwargs["tool_calls"][0]["function"]["name"]
|
94 |
-
if loading_msg:
|
95 |
-
await loading_msg.remove()
|
96 |
-
loading_msg = cl.Message(
|
97 |
-
content=f"π Using {tool_name}...",
|
98 |
-
author="Tool"
|
99 |
-
)
|
100 |
-
await loading_msg.send()
|
101 |
-
# Only stream AI messages, skip tool outputs
|
102 |
-
elif isinstance(last_message, AIMessage):
|
103 |
-
if loading_msg:
|
104 |
-
await loading_msg.remove()
|
105 |
-
loading_msg = None
|
106 |
-
await final_answer.stream_token(last_message.content)
|
107 |
-
|
108 |
-
await final_answer.send()
|
|
|
2 |
from langchain_core.messages import HumanMessage, AIMessage
|
3 |
from langchain.schema.runnable.config import RunnableConfig
|
4 |
import chainlit as cl
|
5 |
+
from agent import create_agent_graph, create_agent_graph_without_memory, get_checkpointer
|
6 |
from agent.utils.state import AgentState
|
7 |
+
import os
|
8 |
+
import json
|
9 |
+
|
10 |
+
# Path to SQLite database for short-term memory
|
11 |
+
SHORT_TERM_MEMORY_DB_PATH = "data/short_term.db"
|
12 |
+
|
13 |
+
# Ensure the data directory exists
|
14 |
+
os.makedirs(os.path.dirname(SHORT_TERM_MEMORY_DB_PATH), exist_ok=True)
|
15 |
|
16 |
@cl.on_chat_start
|
17 |
async def on_chat_start():
|
|
|
19 |
session_id = str(uuid.uuid4())
|
20 |
cl.user_session.set("session_id", session_id)
|
21 |
|
22 |
+
# Initialize empty message history
|
23 |
+
cl.user_session.set("message_history", [])
|
24 |
|
25 |
# Initialize config using stored session ID
|
26 |
config = RunnableConfig(
|
|
|
32 |
|
33 |
# Initialize empty state with auth
|
34 |
try:
|
35 |
+
async with get_checkpointer(SHORT_TERM_MEMORY_DB_PATH) as saver:
|
36 |
+
graph = await create_agent_graph(saver)
|
37 |
+
initial_state = AgentState(messages=[], context=[])
|
38 |
+
await graph.ainvoke(initial_state, config=config)
|
39 |
+
# Store initial state as a serializable dict
|
40 |
+
cl.user_session.set("last_state", {"messages": [], "context": []})
|
41 |
except Exception as e:
|
42 |
print(f"Error initializing state: {str(e)}")
|
43 |
|
|
|
48 |
|
49 |
@cl.on_message
|
50 |
async def on_message(message: cl.Message):
|
51 |
+
# Get or create session ID
|
52 |
session_id = cl.user_session.get("session_id")
|
|
|
53 |
if not session_id:
|
54 |
session_id = str(uuid.uuid4())
|
55 |
cl.user_session.set("session_id", session_id)
|
56 |
|
57 |
+
print(f"Session ID: {session_id}")
|
58 |
+
|
59 |
+
# Get message history
|
60 |
+
message_history = cl.user_session.get("message_history", [])
|
61 |
+
|
62 |
+
# Add new message to history
|
63 |
+
current_message = HumanMessage(content=message.content)
|
64 |
+
message_history.append(current_message)
|
65 |
+
cl.user_session.set("message_history", message_history)
|
66 |
+
|
67 |
config = RunnableConfig(
|
68 |
configurable={
|
69 |
"thread_id": session_id,
|
70 |
+
"checkpoint_ns": session_id, # Use session_id as namespace
|
71 |
"sessionId": session_id
|
72 |
}
|
73 |
)
|
74 |
|
|
|
75 |
try:
|
76 |
+
async with get_checkpointer(SHORT_TERM_MEMORY_DB_PATH) as saver:
|
77 |
+
# Create graph with SQLite memory
|
78 |
+
graph = await create_agent_graph(saver)
|
79 |
+
|
80 |
+
# Get the last state or create new one
|
81 |
+
last_state_dict = cl.user_session.get("last_state", {"messages": [], "context": []})
|
82 |
+
|
83 |
+
# Create new state with current message history
|
84 |
+
current_state = AgentState(
|
85 |
+
messages=message_history,
|
86 |
+
context=last_state_dict.get("context", [])
|
87 |
+
)
|
88 |
+
|
89 |
+
# Setup callback handler and final answer message
|
90 |
+
cb = cl.LangchainCallbackHandler()
|
91 |
+
final_answer = cl.Message(content="")
|
92 |
+
await final_answer.send()
|
93 |
+
|
94 |
+
loading_msg = None # Initialize reference to loading message
|
95 |
+
last_state = None # Track the final state
|
96 |
+
|
97 |
+
# Stream the response
|
98 |
+
async for chunk in graph.astream(
|
99 |
+
current_state,
|
100 |
+
config=config
|
101 |
+
):
|
102 |
+
for node, values in chunk.items():
|
103 |
+
if node == "retrieve":
|
104 |
+
if loading_msg:
|
105 |
+
await loading_msg.remove()
|
106 |
+
loading_msg = cl.Message(content="π Searching knowledge base...", author="System")
|
107 |
+
await loading_msg.send()
|
108 |
+
elif values.get("messages"):
|
109 |
+
last_message = values["messages"][-1]
|
110 |
+
# Check for tool calls in additional_kwargs
|
111 |
+
if hasattr(last_message, "additional_kwargs") and last_message.additional_kwargs.get("tool_calls"):
|
112 |
+
tool_name = last_message.additional_kwargs["tool_calls"][0]["function"]["name"]
|
113 |
+
if loading_msg:
|
114 |
+
await loading_msg.remove()
|
115 |
+
loading_msg = cl.Message(
|
116 |
+
content=f"π Using {tool_name}...",
|
117 |
+
author="Tool"
|
118 |
+
)
|
119 |
+
await loading_msg.send()
|
120 |
+
# Only stream AI messages, skip tool outputs
|
121 |
+
elif isinstance(last_message, AIMessage):
|
122 |
+
if loading_msg:
|
123 |
+
await loading_msg.remove()
|
124 |
+
loading_msg = None
|
125 |
+
await final_answer.stream_token(last_message.content)
|
126 |
+
# Add AI message to history
|
127 |
+
message_history.append(last_message)
|
128 |
+
cl.user_session.set("message_history", message_history)
|
129 |
+
# Update last state
|
130 |
+
last_state = values
|
131 |
+
|
132 |
+
# Update the last state as a serializable dict
|
133 |
+
if last_state:
|
134 |
+
cl.user_session.set("last_state", {
|
135 |
+
"messages": [msg.content for msg in message_history],
|
136 |
+
"context": last_state.get("context", [])
|
137 |
+
})
|
138 |
+
await final_answer.send()
|
139 |
+
|
140 |
except Exception as e:
|
141 |
+
print(f"Error in message handler: {str(e)}")
|
142 |
+
await cl.Message(content="I apologize, but I encountered an error processing your message. Please try again.").send()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|