ric9176 commited on
Commit
cdceb53
·
1 Parent(s): f5dd805

add memory module

Browse files
agent/memory/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .manager import MemoryManager, Message
2
+ from .analysis import MemoryAnalyzer, MemoryAnalysis
3
+
4
+ __all__ = ['MemoryManager', 'Message', 'MemoryAnalyzer', 'MemoryAnalysis']
agent/memory/analysis.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ from langchain_core.prompts import PromptTemplate
3
+ from langchain_openai import ChatOpenAI
4
+ from pydantic import BaseModel
5
+
6
+ class MemoryAnalysis(BaseModel):
7
+ is_important: bool
8
+ formatted_memory: Optional[str] = None
9
+
10
+ MEMORY_ANALYSIS_PROMPT = PromptTemplate.from_template("""Extract and format important personal facts about the user from their message.
11
+ Focus on the actual information, not meta-commentary or requests.
12
+
13
+ Important facts include:
14
+ - Personal details (name, age, location)
15
+ - Professional info (job, education, skills)
16
+ - Preferences (likes, dislikes, favorites)
17
+ - Life circumstances (family, relationships)
18
+ - Significant experiences or achievements
19
+ - Personal goals or aspirations
20
+
21
+ Rules:
22
+ 1. Only extract actual facts, not requests or commentary
23
+ 2. Convert facts into clear, third-person statements
24
+ 3. If no actual facts are present, mark as not important
25
+ 4. Remove conversational elements and focus on core information
26
+
27
+ Examples:
28
+ Input: "Hey, could you remember that I love Star Wars?"
29
+ Output: {
30
+ "is_important": true,
31
+ "formatted_memory": "Loves Star Wars"
32
+ }
33
+
34
+ Input: "Can you remember my details for next time?"
35
+ Output: {
36
+ "is_important": false,
37
+ "formatted_memory": null
38
+ }
39
+
40
+ Message: {message}
41
+ Output:""")
42
+
43
+ class MemoryAnalyzer:
44
+ def __init__(self, temperature: float = 0.1):
45
+ self.llm = ChatOpenAI(temperature=temperature)
46
+
47
+ async def analyze_memory(self, message: str) -> MemoryAnalysis:
48
+ """Analyze a message to determine importance and format if needed."""
49
+ prompt = MEMORY_ANALYSIS_PROMPT.format(message=message)
50
+ response = await self.llm.ainvoke(prompt)
51
+
52
+ # Parse the response into a MemoryAnalysis object
53
+ try:
54
+ # Extract the JSON-like content from the response
55
+ content = response.content
56
+ if isinstance(content, str):
57
+ # Convert string representation to dict
58
+ import json
59
+ content = json.loads(content)
60
+
61
+ return MemoryAnalysis(
62
+ is_important=content.get("is_important", False),
63
+ formatted_memory=content.get("formatted_memory")
64
+ )
65
+ except Exception as e:
66
+ # If parsing fails, return a safe default
67
+ return MemoryAnalysis(is_important=False, formatted_memory=None)
agent/memory/example.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from .manager import MemoryManager, Message
3
+ from .analysis import MemoryAnalyzer
4
+ import uuid
5
+
6
+ async def main():
7
+ # Initialize the memory system
8
+ memory_manager = MemoryManager(
9
+ qdrant_url="http://localhost:6333",
10
+ sqlite_path="data/short_term.db"
11
+ )
12
+ memory_analyzer = MemoryAnalyzer()
13
+
14
+ try:
15
+ # Example: Store a message in short-term memory
16
+ message = Message(
17
+ content="I really enjoy programming in Python and building AI applications.",
18
+ type="human",
19
+ timestamp=datetime.now()
20
+ )
21
+ await memory_manager.store_short_term(message)
22
+
23
+ # Analyze the message for long-term storage
24
+ analysis = await memory_analyzer.analyze_memory(message.content)
25
+
26
+ if analysis.is_important and analysis.formatted_memory:
27
+ # Check for similar existing memories
28
+ similar = memory_manager.find_similar_memory(analysis.formatted_memory)
29
+
30
+ if not similar:
31
+ # Store in long-term memory if no similar memory exists
32
+ memory_manager.store_long_term(
33
+ text=analysis.formatted_memory,
34
+ metadata={
35
+ "id": str(uuid.uuid4()),
36
+ "timestamp": datetime.now().isoformat(),
37
+ "source_message": message.content
38
+ }
39
+ )
40
+
41
+ # Retrieve recent messages from short-term memory
42
+ recent_messages = memory_manager.get_recent_messages(limit=5)
43
+ print("Recent messages:")
44
+ for msg in recent_messages:
45
+ print(f"- {msg.timestamp}: {msg.content}")
46
+
47
+ finally:
48
+ memory_manager.close()
49
+
50
+ if __name__ == "__main__":
51
+ import asyncio
52
+ asyncio.run(main())
agent/memory/manager.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import sqlite3
3
+ import uuid
4
+ from typing import Optional, List
5
+ from pydantic import BaseModel
6
+ from sentence_transformers import SentenceTransformer
7
+ from qdrant_client import QdrantClient
8
+ from qdrant_client.models import PointStruct
9
+ import os
10
+
11
+ class MemoryAnalysis(BaseModel):
12
+ is_important: bool
13
+ formatted_memory: Optional[str] = None
14
+
15
+ class Message(BaseModel):
16
+ content: str
17
+ type: str # "human" or "ai"
18
+ timestamp: datetime
19
+
20
+ class MemoryManager:
21
+ SIMILARITY_THRESHOLD = 0.9
22
+
23
+ def __init__(self, qdrant_url: str = "http://localhost:6333", sqlite_path: str = "data/short_term.db"):
24
+ # Initialize vector store (Qdrant)
25
+ self.model = SentenceTransformer("all-MiniLM-L6-v2")
26
+ self.qdrant = QdrantClient(url=qdrant_url)
27
+
28
+ # Ensure the collection exists
29
+ try:
30
+ self.qdrant.get_collection("memories")
31
+ except:
32
+ self.qdrant.create_collection(
33
+ collection_name="memories",
34
+ vectors_config={
35
+ "size": self.model.get_sentence_embedding_dimension(),
36
+ "distance": "Cosine"
37
+ }
38
+ )
39
+
40
+ # Initialize SQLite for short-term memory
41
+ os.makedirs(os.path.dirname(sqlite_path), exist_ok=True)
42
+ self.sqlite_conn = sqlite3.connect(sqlite_path)
43
+ self._init_sqlite()
44
+
45
+ def _init_sqlite(self):
46
+ """Initialize SQLite tables for short-term memory."""
47
+ cursor = self.sqlite_conn.cursor()
48
+ cursor.execute("""
49
+ CREATE TABLE IF NOT EXISTS conversations (
50
+ id TEXT PRIMARY KEY,
51
+ content TEXT NOT NULL,
52
+ type TEXT NOT NULL,
53
+ timestamp DATETIME NOT NULL
54
+ )
55
+ """)
56
+ self.sqlite_conn.commit()
57
+
58
+ async def store_short_term(self, message: Message) -> None:
59
+ """Store a message in short-term memory (SQLite)."""
60
+ cursor = self.sqlite_conn.cursor()
61
+ cursor.execute(
62
+ "INSERT INTO conversations (id, content, type, timestamp) VALUES (?, ?, ?, ?)",
63
+ (str(uuid.uuid4()), message.content, message.type, message.timestamp.isoformat())
64
+ )
65
+ self.sqlite_conn.commit()
66
+
67
+ def get_recent_messages(self, limit: int = 10) -> List[Message]:
68
+ """Retrieve recent messages from short-term memory."""
69
+ cursor = self.sqlite_conn.cursor()
70
+ cursor.execute(
71
+ "SELECT content, type, timestamp FROM conversations ORDER BY timestamp DESC LIMIT ?",
72
+ (limit,)
73
+ )
74
+ return [
75
+ Message(content=row[0], type=row[1], timestamp=datetime.fromisoformat(row[2]))
76
+ for row in cursor.fetchall()
77
+ ]
78
+
79
+ def store_long_term(self, text: str, metadata: dict) -> None:
80
+ """Store a memory in long-term memory (Qdrant)."""
81
+ embedding = self.model.encode(text)
82
+ self.qdrant.upsert(
83
+ collection_name="memories",
84
+ points=[PointStruct(
85
+ id=metadata["id"],
86
+ vector=embedding.tolist(),
87
+ payload={"text": text, **metadata}
88
+ )]
89
+ )
90
+
91
+ def find_similar_memory(self, text: str) -> Optional[str]:
92
+ """Find similar existing memory using cosine similarity."""
93
+ embedding = self.model.encode(text)
94
+ results = self.qdrant.search(
95
+ collection_name="memories",
96
+ query_vector=embedding.tolist(),
97
+ limit=1
98
+ )
99
+ if results and results[0].score > self.SIMILARITY_THRESHOLD:
100
+ return results[0].payload["text"]
101
+ return None
102
+
103
+ def close(self):
104
+ """Close database connections."""
105
+ self.sqlite_conn.close()
docker-compose.yml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: "3.8"
2
+
3
+ services:
4
+ qdrant:
5
+ image: qdrant/qdrant:latest
6
+ ports:
7
+ - "6333:6333"
8
+ volumes:
9
+ - ./data/long_term_memory:/qdrant/storage
10
+ restart: unless-stopped
11
+
12
+ app:
13
+ build: .
14
+ ports:
15
+ - "8000:8000"
16
+ volumes:
17
+ - .:/app
18
+ - ./data/short_term_memory:/app/data
19
+ environment:
20
+ - QDRANT_URL=http://qdrant:6333
21
+ - OPENAI_API_KEY=${OPENAI_API_KEY}
22
+ depends_on:
23
+ - qdrant