Spaces:
Running
Running
add memory module
Browse files- agent/memory/__init__.py +4 -0
- agent/memory/analysis.py +67 -0
- agent/memory/example.py +52 -0
- agent/memory/manager.py +105 -0
- docker-compose.yml +23 -0
agent/memory/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .manager import MemoryManager, Message
|
2 |
+
from .analysis import MemoryAnalyzer, MemoryAnalysis
|
3 |
+
|
4 |
+
__all__ = ['MemoryManager', 'Message', 'MemoryAnalyzer', 'MemoryAnalysis']
|
agent/memory/analysis.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
from langchain_core.prompts import PromptTemplate
|
3 |
+
from langchain_openai import ChatOpenAI
|
4 |
+
from pydantic import BaseModel
|
5 |
+
|
6 |
+
class MemoryAnalysis(BaseModel):
|
7 |
+
is_important: bool
|
8 |
+
formatted_memory: Optional[str] = None
|
9 |
+
|
10 |
+
MEMORY_ANALYSIS_PROMPT = PromptTemplate.from_template("""Extract and format important personal facts about the user from their message.
|
11 |
+
Focus on the actual information, not meta-commentary or requests.
|
12 |
+
|
13 |
+
Important facts include:
|
14 |
+
- Personal details (name, age, location)
|
15 |
+
- Professional info (job, education, skills)
|
16 |
+
- Preferences (likes, dislikes, favorites)
|
17 |
+
- Life circumstances (family, relationships)
|
18 |
+
- Significant experiences or achievements
|
19 |
+
- Personal goals or aspirations
|
20 |
+
|
21 |
+
Rules:
|
22 |
+
1. Only extract actual facts, not requests or commentary
|
23 |
+
2. Convert facts into clear, third-person statements
|
24 |
+
3. If no actual facts are present, mark as not important
|
25 |
+
4. Remove conversational elements and focus on core information
|
26 |
+
|
27 |
+
Examples:
|
28 |
+
Input: "Hey, could you remember that I love Star Wars?"
|
29 |
+
Output: {
|
30 |
+
"is_important": true,
|
31 |
+
"formatted_memory": "Loves Star Wars"
|
32 |
+
}
|
33 |
+
|
34 |
+
Input: "Can you remember my details for next time?"
|
35 |
+
Output: {
|
36 |
+
"is_important": false,
|
37 |
+
"formatted_memory": null
|
38 |
+
}
|
39 |
+
|
40 |
+
Message: {message}
|
41 |
+
Output:""")
|
42 |
+
|
43 |
+
class MemoryAnalyzer:
|
44 |
+
def __init__(self, temperature: float = 0.1):
|
45 |
+
self.llm = ChatOpenAI(temperature=temperature)
|
46 |
+
|
47 |
+
async def analyze_memory(self, message: str) -> MemoryAnalysis:
|
48 |
+
"""Analyze a message to determine importance and format if needed."""
|
49 |
+
prompt = MEMORY_ANALYSIS_PROMPT.format(message=message)
|
50 |
+
response = await self.llm.ainvoke(prompt)
|
51 |
+
|
52 |
+
# Parse the response into a MemoryAnalysis object
|
53 |
+
try:
|
54 |
+
# Extract the JSON-like content from the response
|
55 |
+
content = response.content
|
56 |
+
if isinstance(content, str):
|
57 |
+
# Convert string representation to dict
|
58 |
+
import json
|
59 |
+
content = json.loads(content)
|
60 |
+
|
61 |
+
return MemoryAnalysis(
|
62 |
+
is_important=content.get("is_important", False),
|
63 |
+
formatted_memory=content.get("formatted_memory")
|
64 |
+
)
|
65 |
+
except Exception as e:
|
66 |
+
# If parsing fails, return a safe default
|
67 |
+
return MemoryAnalysis(is_important=False, formatted_memory=None)
|
agent/memory/example.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
from .manager import MemoryManager, Message
|
3 |
+
from .analysis import MemoryAnalyzer
|
4 |
+
import uuid
|
5 |
+
|
6 |
+
async def main():
|
7 |
+
# Initialize the memory system
|
8 |
+
memory_manager = MemoryManager(
|
9 |
+
qdrant_url="http://localhost:6333",
|
10 |
+
sqlite_path="data/short_term.db"
|
11 |
+
)
|
12 |
+
memory_analyzer = MemoryAnalyzer()
|
13 |
+
|
14 |
+
try:
|
15 |
+
# Example: Store a message in short-term memory
|
16 |
+
message = Message(
|
17 |
+
content="I really enjoy programming in Python and building AI applications.",
|
18 |
+
type="human",
|
19 |
+
timestamp=datetime.now()
|
20 |
+
)
|
21 |
+
await memory_manager.store_short_term(message)
|
22 |
+
|
23 |
+
# Analyze the message for long-term storage
|
24 |
+
analysis = await memory_analyzer.analyze_memory(message.content)
|
25 |
+
|
26 |
+
if analysis.is_important and analysis.formatted_memory:
|
27 |
+
# Check for similar existing memories
|
28 |
+
similar = memory_manager.find_similar_memory(analysis.formatted_memory)
|
29 |
+
|
30 |
+
if not similar:
|
31 |
+
# Store in long-term memory if no similar memory exists
|
32 |
+
memory_manager.store_long_term(
|
33 |
+
text=analysis.formatted_memory,
|
34 |
+
metadata={
|
35 |
+
"id": str(uuid.uuid4()),
|
36 |
+
"timestamp": datetime.now().isoformat(),
|
37 |
+
"source_message": message.content
|
38 |
+
}
|
39 |
+
)
|
40 |
+
|
41 |
+
# Retrieve recent messages from short-term memory
|
42 |
+
recent_messages = memory_manager.get_recent_messages(limit=5)
|
43 |
+
print("Recent messages:")
|
44 |
+
for msg in recent_messages:
|
45 |
+
print(f"- {msg.timestamp}: {msg.content}")
|
46 |
+
|
47 |
+
finally:
|
48 |
+
memory_manager.close()
|
49 |
+
|
50 |
+
if __name__ == "__main__":
|
51 |
+
import asyncio
|
52 |
+
asyncio.run(main())
|
agent/memory/manager.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
import sqlite3
|
3 |
+
import uuid
|
4 |
+
from typing import Optional, List
|
5 |
+
from pydantic import BaseModel
|
6 |
+
from sentence_transformers import SentenceTransformer
|
7 |
+
from qdrant_client import QdrantClient
|
8 |
+
from qdrant_client.models import PointStruct
|
9 |
+
import os
|
10 |
+
|
11 |
+
class MemoryAnalysis(BaseModel):
|
12 |
+
is_important: bool
|
13 |
+
formatted_memory: Optional[str] = None
|
14 |
+
|
15 |
+
class Message(BaseModel):
|
16 |
+
content: str
|
17 |
+
type: str # "human" or "ai"
|
18 |
+
timestamp: datetime
|
19 |
+
|
20 |
+
class MemoryManager:
|
21 |
+
SIMILARITY_THRESHOLD = 0.9
|
22 |
+
|
23 |
+
def __init__(self, qdrant_url: str = "http://localhost:6333", sqlite_path: str = "data/short_term.db"):
|
24 |
+
# Initialize vector store (Qdrant)
|
25 |
+
self.model = SentenceTransformer("all-MiniLM-L6-v2")
|
26 |
+
self.qdrant = QdrantClient(url=qdrant_url)
|
27 |
+
|
28 |
+
# Ensure the collection exists
|
29 |
+
try:
|
30 |
+
self.qdrant.get_collection("memories")
|
31 |
+
except:
|
32 |
+
self.qdrant.create_collection(
|
33 |
+
collection_name="memories",
|
34 |
+
vectors_config={
|
35 |
+
"size": self.model.get_sentence_embedding_dimension(),
|
36 |
+
"distance": "Cosine"
|
37 |
+
}
|
38 |
+
)
|
39 |
+
|
40 |
+
# Initialize SQLite for short-term memory
|
41 |
+
os.makedirs(os.path.dirname(sqlite_path), exist_ok=True)
|
42 |
+
self.sqlite_conn = sqlite3.connect(sqlite_path)
|
43 |
+
self._init_sqlite()
|
44 |
+
|
45 |
+
def _init_sqlite(self):
|
46 |
+
"""Initialize SQLite tables for short-term memory."""
|
47 |
+
cursor = self.sqlite_conn.cursor()
|
48 |
+
cursor.execute("""
|
49 |
+
CREATE TABLE IF NOT EXISTS conversations (
|
50 |
+
id TEXT PRIMARY KEY,
|
51 |
+
content TEXT NOT NULL,
|
52 |
+
type TEXT NOT NULL,
|
53 |
+
timestamp DATETIME NOT NULL
|
54 |
+
)
|
55 |
+
""")
|
56 |
+
self.sqlite_conn.commit()
|
57 |
+
|
58 |
+
async def store_short_term(self, message: Message) -> None:
|
59 |
+
"""Store a message in short-term memory (SQLite)."""
|
60 |
+
cursor = self.sqlite_conn.cursor()
|
61 |
+
cursor.execute(
|
62 |
+
"INSERT INTO conversations (id, content, type, timestamp) VALUES (?, ?, ?, ?)",
|
63 |
+
(str(uuid.uuid4()), message.content, message.type, message.timestamp.isoformat())
|
64 |
+
)
|
65 |
+
self.sqlite_conn.commit()
|
66 |
+
|
67 |
+
def get_recent_messages(self, limit: int = 10) -> List[Message]:
|
68 |
+
"""Retrieve recent messages from short-term memory."""
|
69 |
+
cursor = self.sqlite_conn.cursor()
|
70 |
+
cursor.execute(
|
71 |
+
"SELECT content, type, timestamp FROM conversations ORDER BY timestamp DESC LIMIT ?",
|
72 |
+
(limit,)
|
73 |
+
)
|
74 |
+
return [
|
75 |
+
Message(content=row[0], type=row[1], timestamp=datetime.fromisoformat(row[2]))
|
76 |
+
for row in cursor.fetchall()
|
77 |
+
]
|
78 |
+
|
79 |
+
def store_long_term(self, text: str, metadata: dict) -> None:
|
80 |
+
"""Store a memory in long-term memory (Qdrant)."""
|
81 |
+
embedding = self.model.encode(text)
|
82 |
+
self.qdrant.upsert(
|
83 |
+
collection_name="memories",
|
84 |
+
points=[PointStruct(
|
85 |
+
id=metadata["id"],
|
86 |
+
vector=embedding.tolist(),
|
87 |
+
payload={"text": text, **metadata}
|
88 |
+
)]
|
89 |
+
)
|
90 |
+
|
91 |
+
def find_similar_memory(self, text: str) -> Optional[str]:
|
92 |
+
"""Find similar existing memory using cosine similarity."""
|
93 |
+
embedding = self.model.encode(text)
|
94 |
+
results = self.qdrant.search(
|
95 |
+
collection_name="memories",
|
96 |
+
query_vector=embedding.tolist(),
|
97 |
+
limit=1
|
98 |
+
)
|
99 |
+
if results and results[0].score > self.SIMILARITY_THRESHOLD:
|
100 |
+
return results[0].payload["text"]
|
101 |
+
return None
|
102 |
+
|
103 |
+
def close(self):
|
104 |
+
"""Close database connections."""
|
105 |
+
self.sqlite_conn.close()
|
docker-compose.yml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
version: "3.8"
|
2 |
+
|
3 |
+
services:
|
4 |
+
qdrant:
|
5 |
+
image: qdrant/qdrant:latest
|
6 |
+
ports:
|
7 |
+
- "6333:6333"
|
8 |
+
volumes:
|
9 |
+
- ./data/long_term_memory:/qdrant/storage
|
10 |
+
restart: unless-stopped
|
11 |
+
|
12 |
+
app:
|
13 |
+
build: .
|
14 |
+
ports:
|
15 |
+
- "8000:8000"
|
16 |
+
volumes:
|
17 |
+
- .:/app
|
18 |
+
- ./data/short_term_memory:/app/data
|
19 |
+
environment:
|
20 |
+
- QDRANT_URL=http://qdrant:6333
|
21 |
+
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
22 |
+
depends_on:
|
23 |
+
- qdrant
|