import os

# Point the Hugging Face cache at the user's home directory.
# This must be set before the Hugging Face libraries are imported so it takes effect.
os.environ['HF_HOME'] = os.path.join(
    os.path.expanduser('~'), '.cache', "huggingface")

from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain.tools.retriever import create_retriever_tool
from langchain_core.tools import BaseTool
from langgraph.graph import START, StateGraph, MessagesState, END
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain.vectorstores import VectorStore
from langchain_core.language_models import BaseChatModel
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
# from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_groq import ChatGroq
from basic_tools import *
from typing import List
import numpy as np
from datetime import datetime, timedelta
from sentence_transformers import SentenceTransformer
import torch
import heapq
from utils import *

# Load the system prompt from the file
with open("./system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# System message
sys_msg = SystemMessage(content=system_prompt)
class BasicAgent:
    tools: List[BaseTool] = [
        multiply, add, subtract, divide, modulus,
        wiki_search, web_search, arxiv_search,
        python_repl, analyze_image,
        date_filter, analyze_content,
        step_by_step_reasoning, translate_text,
    ]

    def __init__(self, embeddings: HuggingFaceEmbeddings, vector_store: VectorStore, llm: BaseChatModel):
        self.embedding_model = embeddings
        self.vector_store = vector_store
        ret = self.vector_store.as_retriever()
        self.retriever = create_retriever_tool(
            retriever=ret,  # type: ignore
            name="Question Search",  # type: ignore
            description="A tool to retrieve similar questions from a vector store."  # type: ignore
        )
        self.llm = llm.bind_tools(self.tools)
        self.graph = self.build_graph()
        print("BasicAgent initialized.")
    def __call__(self, question: str) -> str:
        print(f"Agent received question (first 50 chars): {question[:50]}...")

        # Search for similar content to enhance context (limited to a single document)
        similar_docs = self.vector_store.similarity_search(question, k=1)

        # Create enhanced context with relevant past information
        enhanced_context = question
        if similar_docs:
            context_additions = []
            for doc in similar_docs:
                # Extract relevant information from similar documents
                content = doc.page_content
                if "Question:" in content and "Final answer:" in content:
                    q = content.split("Question:")[1].split("Final answer:")[0].strip()
                    a = content.split("Final answer:")[1].split("Timestamp:", 1)[0].strip()
                    # Truncate long contexts
                    if len(q) > 200:
                        q = q[:200] + "..."
                    if len(a) > 300:
                        a = a[:300] + "..."
                    # Only add if it's not exactly the same question
                    if question.lower() != q.lower():
                        context_additions.append(f"Related Q: {q}\nRelated A: {a}")

            if context_additions:
                enhanced_context = (
                    "Consider this relevant information first:\n\n" +
                    "\n\n".join(context_additions[:1]) +  # Only use the first context addition
                    "\n\nNow answering this question: " + question
                )

        # Process with the graph
        input_messages = [HumanMessage(content=enhanced_context)]
        result = self.graph.invoke({"messages": input_messages})
        answer = result["messages"][-1].content

        # Store this Q&A pair for future reference
        self._cache_result(question, answer)

        print(f"Agent returning answer (first 50 chars): {answer[:50]}...")
        return answer
    def _cache_result(self, question: str, answer: str) -> None:
        """Cache the question and answer in the vector store."""
        timestamp = datetime.now().isoformat()
        content = f"Question: {question}\nFinal answer: {answer}\nTimestamp: {timestamp}"

        # Create document with metadata
        doc = Document(
            page_content=content,
            metadata={
                "question": question,
                "timestamp": timestamp,
                "type": "qa_pair"
            }
        )

        # Add to vector store
        self.vector_store.add_documents([doc])
        print("Cached new Q&A in vector store")
    # Build graph function
    def build_graph(self):
        """Build the graph with context enhancement."""

        def context_enhanced_generation(state: MessagesState):
            """Node that enhances context with relevant information."""
            query = str(state["messages"][-1].content)

            # Retrieve relevant information
            similar_docs = self.vector_store.similarity_search(query, k=3)

            # Extract relevant context
            context = ""
            if similar_docs:
                context_pieces = []
                for doc in similar_docs:
                    content = doc.page_content
                    # Extract the relevant parts
                    if "Question:" in content:
                        context_pieces.append(content)
                if context_pieces:
                    context = "Relevant context:\n\n" + "\n\n".join(context_pieces) + "\n\n"

            # Create enhanced messages
            enhanced_messages = state["messages"].copy()
            if context:
                # Add context to the system message if one exists, otherwise insert a new one
                system_message_found = False
                for i, msg in enumerate(enhanced_messages):
                    if isinstance(msg, SystemMessage):
                        enhanced_messages[i] = SystemMessage(content=f"{msg.content}\n\n{context}")
                        system_message_found = True
                        break
                if not system_message_found:
                    enhanced_messages.insert(0, SystemMessage(content=context))

            # Process with LLM
            response = self.llm.invoke(enhanced_messages)
            return {"messages": state["messages"] + [response]}

        # Tool handling node
        tool_node = ToolNode(self.tools)

        # Build graph with tool handling
        builder = StateGraph(MessagesState)
        builder.add_node("context_enhanced_generation", context_enhanced_generation)
        builder.add_node("tools", tool_node)

        # Connect nodes
        builder.set_entry_point("context_enhanced_generation")
        builder.add_conditional_edges(
            "context_enhanced_generation",
            tools_condition,
            {
                "tools": "tools",
                END: END  # tools_condition returns END when no tool call is made
            }
        )
        builder.add_edge("tools", "context_enhanced_generation")

        return builder.compile()
    @staticmethod
    def get_llm(provider: str = "groq") -> BaseChatModel:
        # API keys are expected in the environment (e.g. GROQ_API_KEY, OPENAI_API_KEY)
        if provider == "groq":
            # Groq https://console.groq.com/docs/models
            # optional: qwen-qwq-32b gemma2-9b-it
            llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
        elif provider == "huggingface":
            # TODO: Add huggingface endpoint
            llm = ChatHuggingFace(
                llm=HuggingFaceEndpoint(
                    model="Meta-DeepLearning/llama-2-7b-chat-hf",
                    temperature=0,
                ),
            )
        elif provider == "openai_local":
            from langchain_openai import ChatOpenAI
            llm = ChatOpenAI(
                base_url="http://localhost:11432/v1",  # local OpenAI-compatible server (e.g. LM Studio)
                api_key="not-used",  # required by the interface but ignored locally  # type: ignore
                # model="mistral-nemo-instruct-2407",
                model="meta-llama-3.1-8b-instruct",
                temperature=0.2
            )
        elif provider == "openai":
            from langchain_openai import ChatOpenAI
            llm = ChatOpenAI(
                model="gpt-4o",
                temperature=0.2,
            )
        else:
            raise ValueError(
                "Invalid provider. Choose 'groq', 'huggingface', 'openai_local' or 'openai'.")
        return llm
    def manage_memory(self, max_documents: int = 1000, max_age_days: int = 30) -> None:
        """
        Manage memory by pruning old or less useful entries from the vector store.
        This implementation works with various vector store types, not just FAISS.

        Args:
            max_documents: Maximum number of documents to keep
            max_age_days: Remove documents older than this many days
        """
        print("Starting memory management...")

        # Get all documents from the vector store
        try:
            # For vector stores that have a get_all_documents method
            if hasattr(self.vector_store, "get_all_documents"):
                all_docs = self.vector_store.get_all_documents()
                all_ids = [doc.metadata.get("id", i) for i, doc in enumerate(all_docs)]
            # For FAISS and similar implementations
            elif hasattr(self.vector_store, "docstore") and hasattr(self.vector_store, "index_to_docstore_id"):
                # Access the docstore in a more robust way
                if hasattr(self.vector_store.docstore, "docstore"):
                    all_ids = list(self.vector_store.index_to_docstore_id.values())
                    all_docs = []
                    for doc_id in all_ids:
                        doc = self.vector_store.docstore.search(doc_id)
                        if doc:
                            all_docs.append(doc)
                else:
                    # Fallback for newer FAISS implementations
                    try:
                        all_docs = []
                        all_ids = []
                        # Walk every index position and map it to a document ID
                        for i in range(self.vector_store.index.ntotal):
                            if i in self.vector_store.index_to_docstore_id:
                                doc_id = self.vector_store.index_to_docstore_id[i]
                                doc = self.vector_store.docstore.search(doc_id)
                                if doc:
                                    all_docs.append(doc)
                                    all_ids.append(doc_id)
                    except Exception as e:
                        print(f"Error accessing FAISS documents: {e}")
                        all_docs = []
                        all_ids = []
            else:
                print("Warning: Vector store doesn't expose required attributes for memory management")
                return
        except Exception as e:
            print(f"Error accessing vector store documents: {e}")
            return

        if not all_docs:
            print("No documents found in vector store")
            return

        print(f"Retrieved {len(all_docs)} documents for scoring")

        # Score each document based on recency and other factors
        scored_docs = []
        cutoff_date = datetime.now() - timedelta(days=max_age_days)

        for i, doc in enumerate(all_docs):
            doc_id = all_ids[i] if i < len(all_ids) else i

            # Extract timestamp from metadata if present
            timestamp = None
            if hasattr(doc, "metadata") and doc.metadata and "timestamp" in doc.metadata:
                try:
                    timestamp = datetime.fromisoformat(doc.metadata["timestamp"])
                except (ValueError, TypeError):
                    pass

            # If no timestamp in metadata, try to extract it from the content
            if not timestamp and hasattr(doc, "page_content") and "Timestamp:" in doc.page_content:
                try:
                    timestamp_str = doc.page_content.split("Timestamp:")[-1].strip().split('\n')[0]
                    timestamp = datetime.fromisoformat(timestamp_str)
                except (ValueError, TypeError):
                    timestamp = datetime.now() - timedelta(days=max_age_days + 1)

            # If still no timestamp, treat the document as older than the cutoff
            if not timestamp:
                timestamp = datetime.now() - timedelta(days=max_age_days + 1)

            # Calculate age score (newer is better)
            age_factor = max(0.0, min(1.0, (timestamp - cutoff_date).total_seconds() /
                                      (datetime.now() - cutoff_date).total_seconds()))

            # Calculate importance score - could be based on various factors
            importance_factor = 1.0

            # If the document has been accessed often, increase its importance
            if hasattr(doc, "metadata") and doc.metadata and "access_count" in doc.metadata:
                importance_factor += min(1.0, doc.metadata["access_count"] / 10)

            # Combined score (higher = more valuable to keep)
            total_score = (0.7 * age_factor) + (0.3 * importance_factor)

            # Push onto the priority queue (negated score for max-heap behaviour)
            heapq.heappush(scored_docs, (-total_score, i, doc))

        # Select the top documents to keep
        docs_to_keep = []
        for _ in range(min(max_documents, len(scored_docs))):
            if scored_docs:
                _, _, doc = heapq.heappop(scored_docs)
                docs_to_keep.append(doc)

        # Only rebuild if we're actually pruning some documents
        if len(docs_to_keep) < len(all_docs):
            print(f"Memory management: Keeping {len(docs_to_keep)} documents out of {len(all_docs)}")

            # Create a new vector store of the same type as the current one
            vector_store_type = type(self.vector_store)

            # Different approaches based on vector store type
            if hasattr(vector_store_type, "from_documents"):
                # Most langchain vector stores support this method
                new_vector_store = vector_store_type.from_documents(
                    docs_to_keep,
                    embedding=self.embedding_model
                )
                self.vector_store = new_vector_store
                print(f"Vector store rebuilt with {len(docs_to_keep)} documents")
            elif hasattr(vector_store_type, "from_texts"):
                # For vector stores that use from_texts
                texts = [doc.page_content for doc in docs_to_keep]
                metadatas = [doc.metadata if hasattr(doc, "metadata") else {} for doc in docs_to_keep]
                new_vector_store = vector_store_type.from_texts(
                    texts=texts,
                    embedding=self.embedding_model,
                    metadatas=metadatas
                )
                self.vector_store = new_vector_store
                print(f"Vector store rebuilt with {len(docs_to_keep)} documents")
            else:
                print("Warning: Could not determine how to rebuild the vector store")
                print(f"Vector store type: {vector_store_type.__name__}")
    def capture_tool_result(self, tool_name: str, tool_input: str, tool_output: str) -> None:
        """
        Capture knowledge gained from tool usage for future reference.

        Args:
            tool_name: Name of the tool used
            tool_input: Input/query sent to the tool
            tool_output: Result returned by the tool
        """
        # Format the content
        timestamp = datetime.now().isoformat()
        content = (
            f"Tool Knowledge\n"
            f"Tool: {tool_name}\n"
            f"Query: {tool_input}\n"
            f"Result: {tool_output}\n"
            f"Timestamp: {timestamp}"
        )

        # Create document with metadata
        doc = Document(
            page_content=content,
            metadata={
                "type": "tool_knowledge",
                "tool": tool_name,
                "timestamp": timestamp,
                "query": tool_input
            }
        )

        # Add to vector store
        self.vector_store.add_documents([doc])
        print(f"Captured knowledge from tool '{tool_name}' in vector store")