# ssandy_agents/agent.py - reasoning agent (Sheshank Joshi)
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain.tools.retriever import create_retriever_tool
from langchain_core.tools import BaseTool
from langgraph.graph import START, StateGraph, MessagesState, END
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.vectorstores import VectorStore
from langchain_core.language_models import BaseChatModel
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
# from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_groq import ChatGroq
from basic_tools import *
from typing import List
from datetime import datetime, timedelta
import heapq
from utils import *
import os

# Point the Hugging Face cache at the user's home directory before any models are loaded.
os.environ['HF_HOME'] = os.path.join(
    os.path.expanduser('~'), '.cache', "huggingface")
# load the system prompt from the file
with open("./system_prompt.txt", "r", encoding="utf-8") as f:
system_prompt = f.read()
# System message
sys_msg = SystemMessage(content=system_prompt)
class BasicAgent:
    tools: List[BaseTool] = [
        multiply, add, subtract, divide, modulus,
        wiki_search, web_search, arxiv_search,
        python_repl, analyze_image,
        date_filter, analyze_content,
        step_by_step_reasoning, translate_text,
    ]
def __init__(self, embeddings: HuggingFaceEmbeddings, vector_store: VectorStore, llm: BaseChatModel):
self.embedding_model = embeddings
self.vector_store = vector_store
ret = self.vector_store.as_retriever()
self.retriever = create_retriever_tool(
retriever=ret, #type: ignore
name="Question Search", #type: ignore
description="A tool to retrieve similar questions from a vector store." #type: ignore
)
self.llm = llm.bind_tools(self.tools)
self.graph = self.build_graph()
print("BasicAgent initialized.")
def __call__(self, question: str) -> str:
print(f"Agent received question (first 50 chars): {question[:50]}...")
# Search for similar content to enhance context - LIMIT TO 1 DOCUMENT ONLY
similar_docs = self.vector_store.similarity_search(question, k=1) # Reduced from 3 to 1
# Create enhanced context with relevant past information
enhanced_context = question
        if similar_docs:
context_additions = []
for doc in similar_docs:
# Extract relevant information from similar documents
content = doc.page_content
if "Question:" in content and "Final answer:" in content:
q = content.split("Question:")[1].split("Final answer:")[0].strip()
a = content.split("Final answer:")[1].split("Timestamp:", 1)[0].strip()
# Truncate long contexts
if len(q) > 200:
q = q[:200] + "..."
if len(a) > 300:
a = a[:300] + "..."
# Only add if it's not exactly the same question
                    if question.lower() != q.lower():
context_additions.append(f"Related Q: {q}\nRelated A: {a}")
if context_additions:
enhanced_context = (
"Consider this relevant information first:\n\n" +
"\n\n".join(context_additions[:1]) + # Only use the first context addition
"\n\nNow answering this question: " + question
)
# Process with the graph
input_messages = [HumanMessage(content=enhanced_context)]
result = self.graph.invoke({"messages": input_messages})
answer = result["messages"][-1].content
# Store this Q&A pair for future reference
self._cache_result(question, answer)
print(f"Agent returning answer (first 50 chars): {answer[:50]}...")
return answer
def _cache_result(self, question: str, answer: str) -> None:
"""Cache the question and answer in the vector store"""
timestamp = datetime.now().isoformat()
content = f"Question: {question}\nFinal answer: {answer}\nTimestamp: {timestamp}"
# Create document with metadata
doc = Document(
page_content=content,
metadata={
"question": question,
"timestamp": timestamp,
"type": "qa_pair"
}
)
# Add to vector store
self.vector_store.add_documents([doc])
print(f"Cached new Q&A in vector store")
# Build graph function
def build_graph(self):
"""Build the graph with context enhancement"""
def context_enhanced_generation(state: MessagesState):
"""Node that enhances context with relevant information"""
query = str(state["messages"][-1].content)
# Retrieve relevant information
similar_docs = self.vector_store.similarity_search(query, k=3)
# Extract relevant context
context = ""
if similar_docs:
context_pieces = []
for doc in similar_docs:
content = doc.page_content
# Extract the relevant parts
if "Question:" in content:
context_pieces.append(content)
if context_pieces:
context = "Relevant context:\n\n" + "\n\n".join(context_pieces) + "\n\n"
            # Create the enhanced message list and make sure the loaded system prompt is applied
            enhanced_messages = state["messages"].copy()
            # Append retrieved context to an existing system message if present; otherwise
            # prepend the loaded system prompt (sys_msg) together with any retrieved context.
            system_message_found = False
            for i, msg in enumerate(enhanced_messages):
                if isinstance(msg, SystemMessage):
                    if context:
                        enhanced_messages[i] = SystemMessage(content=f"{msg.content}\n\n{context}")
                    system_message_found = True
                    break
            if not system_message_found:
                enhanced_messages.insert(0, SystemMessage(content=f"{sys_msg.content}\n\n{context}".strip()))
            # Process with the LLM; MessagesState's add_messages reducer appends the
            # response to the existing history, so only the new message is returned.
            response = self.llm.invoke(enhanced_messages)
            return {"messages": [response]}
# Tool handling node
tool_node = ToolNode(self.tools)
# Build graph with tool handling
builder = StateGraph(MessagesState)
builder.add_node("context_enhanced_generation", context_enhanced_generation)
builder.add_node("tools", tool_node)
# Connect nodes
builder.set_entry_point("context_enhanced_generation")
builder.add_conditional_edges(
"context_enhanced_generation",
tools_condition,
{
"tools": "tools",
END: END # Using END as the key instead of None
}
)
builder.add_edge("tools", "context_enhanced_generation")
return builder.compile()
    @staticmethod
    def get_llm(provider: str = "groq") -> BaseChatModel:
        # Expects provider credentials (e.g. GROQ_API_KEY) to be set in the environment / .env file.
        if provider == "groq":
            # Groq models: https://console.groq.com/docs/models
            # Other options: gemma2-9b-it
            llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
elif provider == "huggingface":
# TODO: Add huggingface endpoint
llm = ChatHuggingFace(
llm=HuggingFaceEndpoint(
model="Meta-DeepLearning/llama-2-7b-chat-hf",
temperature=0,
),
)
elif provider == "openai_local":
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(
base_url="http://localhost:11432/v1", # default LM Studio endpoint
api_key="not-used", # required by interface but ignored #type: ignore
# model="mistral-nemo-instruct-2407",
model="meta-llama-3.1-8b-instruct",
temperature=0.2
)
elif provider == "openai":
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(
model="gpt-4o",
temperature=0.2,
)
else:
            raise ValueError(
                "Invalid provider. Choose 'groq', 'huggingface', 'openai_local', or 'openai'.")
return llm
def manage_memory(self, max_documents: int = 1000, max_age_days: int = 30) -> None:
"""
Manage memory by pruning old or less useful entries from the vector store.
This implementation works with various vector store types, not just FAISS.
Args:
max_documents: Maximum number of documents to keep
max_age_days: Remove documents older than this many days
"""
print(f"Starting memory management...")
# Get all documents from the vector store
try:
# For vector stores that have a get_all_documents method
if hasattr(self.vector_store, "get_all_documents"):
all_docs = self.vector_store.get_all_documents()
all_ids = [doc.metadata.get("id", i) for i, doc in enumerate(all_docs)]
# For FAISS and similar implementations
elif hasattr(self.vector_store, "docstore") and hasattr(self.vector_store, "index_to_docstore_id"):
                # The docstore exposes a search() method for looking up documents by ID
                if hasattr(self.vector_store.docstore, "search"):
all_ids = list(self.vector_store.index_to_docstore_id.values())
all_docs = []
for doc_id in all_ids:
doc = self.vector_store.docstore.search(doc_id)
if doc:
all_docs.append(doc)
else:
# Fallback for newer FAISS implementations
try:
all_docs = []
all_ids = []
# Get all index positions
for i in range(self.vector_store.index.ntotal):
# Map index position to document ID
if i in self.vector_store.index_to_docstore_id:
doc_id = self.vector_store.index_to_docstore_id[i]
doc = self.vector_store.docstore.search(doc_id)
if doc:
all_docs.append(doc)
all_ids.append(doc_id)
except Exception as e:
print(f"Error accessing FAISS documents: {e}")
all_docs = []
all_ids = []
else:
print("Warning: Vector store doesn't expose required attributes for memory management")
return
except Exception as e:
print(f"Error accessing vector store documents: {e}")
return
if not all_docs:
print("No documents found in vector store")
return
print(f"Retrieved {len(all_docs)} documents for scoring")
# Score each document based on recency and other factors
scored_docs = []
cutoff_date = datetime.now() - timedelta(days=max_age_days)
for i, doc in enumerate(all_docs):
doc_id = all_ids[i] if i < len(all_ids) else i
# Extract timestamp from content or metadata
timestamp = None
if hasattr(doc, "metadata") and doc.metadata and "timestamp" in doc.metadata:
try:
timestamp = datetime.fromisoformat(doc.metadata["timestamp"])
except (ValueError, TypeError):
pass
# If no timestamp in metadata, try to extract from content
if not timestamp and hasattr(doc, "page_content") and "Timestamp:" in doc.page_content:
try:
timestamp_str = doc.page_content.split("Timestamp:")[-1].strip().split('\n')[0]
timestamp = datetime.fromisoformat(timestamp_str)
except (ValueError, TypeError):
timestamp = datetime.now() - timedelta(days=max_age_days+1)
# If still no timestamp, use a default
if not timestamp:
timestamp = datetime.now() - timedelta(days=max_age_days+1)
# Calculate age score (newer is better)
age_factor = max(0.0, min(1.0, (timestamp - cutoff_date).total_seconds() /
(datetime.now() - cutoff_date).total_seconds()))
# Calculate importance score - could be based on various factors
importance_factor = 1.0
# If document has been accessed often, increase importance
if hasattr(doc, "metadata") and doc.metadata and "access_count" in doc.metadata:
importance_factor += min(1.0, doc.metadata["access_count"] / 10)
# Create combined score (higher = more valuable to keep)
total_score = (0.7 * age_factor) + (0.3 * importance_factor)
# Add to priority queue (negative for max-heap behavior)
heapq.heappush(scored_docs, (-total_score, i, doc))
# Select top documents to keep
docs_to_keep = []
for _ in range(min(max_documents, len(scored_docs))):
if scored_docs:
_, _, doc = heapq.heappop(scored_docs)
docs_to_keep.append(doc)
# Only rebuild if we're actually pruning some documents
if len(docs_to_keep) < len(all_docs):
print(f"Memory management: Keeping {len(docs_to_keep)} documents out of {len(all_docs)}")
# Create a new vector store with the same type as the current one
vector_store_type = type(self.vector_store)
# Different approaches based on vector store type
if hasattr(vector_store_type, "from_documents"):
# Most langchain vector stores support this method
new_vector_store = vector_store_type.from_documents(
docs_to_keep,
embedding=self.embedding_model
)
self.vector_store = new_vector_store
print(f"Vector store rebuilt with {len(docs_to_keep)} documents")
elif hasattr(vector_store_type, "from_texts"):
# For vector stores that use from_texts
texts = [doc.page_content for doc in docs_to_keep]
metadatas = [doc.metadata if hasattr(doc, "metadata") else {} for doc in docs_to_keep]
new_vector_store = vector_store_type.from_texts(
texts=texts,
embedding=self.embedding_model,
metadatas=metadatas
)
self.vector_store = new_vector_store
print(f"Vector store rebuilt with {len(docs_to_keep)} documents")
else:
print("Warning: Could not determine how to rebuild the vector store")
print(f"Vector store type: {vector_store_type.__name__}")
def capture_tool_result(self, tool_name: str, tool_input: str, tool_output: str) -> None:
"""
Capture knowledge gained from tool usage for future reference
Args:
tool_name: Name of the tool used
tool_input: Input/query sent to the tool
tool_output: Result returned by the tool
"""
# Format the content
timestamp = datetime.now().isoformat()
content = (
f"Tool Knowledge\n"
f"Tool: {tool_name}\n"
f"Query: {tool_input}\n"
f"Result: {tool_output}\n"
f"Timestamp: {timestamp}"
)
# Create document with metadata
doc = Document(
page_content=content,
metadata={
"type": "tool_knowledge",
"tool": tool_name,
"timestamp": timestamp,
"query": tool_input
}
)
# Add to vector store
self.vector_store.add_documents([doc])
print(f"Captured knowledge from tool '{tool_name}' in vector store")