from typing import List, Dict, Any, Optional, TypedDict
from datetime import datetime, timedelta
import heapq
import json

import torch
from langchain_core.tools import BaseTool
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, SystemMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.vectorstores import VectorStore
from langchain_core.documents import Document
from langchain.tools.retriever import create_retriever_tool
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.graph import StateGraph, END | |
from langgraph.prebuilt import ( | |
ToolNode, | |
ToolInvocation, | |
agent_executor, | |
create_function_calling_executor, | |
AgentState, | |
MessageGraph | |
) | |
from langgraph.prebuilt.tool_executor import ToolExecutor, extract_tool_invocations | |
from langgraph.prebuilt.tool_nodes import get_default_tool_node_parser | |
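

# Current langgraph releases do not ship the ToolExecutor / AgentState helpers that
# older versions exposed from langgraph.prebuilt, so the graph state is defined
# locally. This is a minimal schema matching what the nodes below read and write.
class AgentState(TypedDict):
    """State passed between the agent and tool nodes of the graph."""
    messages: List[BaseMessage]
    tools: List[BaseTool]
    tool_calls: List[dict]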


class AdvancedToolAgent:
    """
    An advanced agent with robust tool-calling capabilities using LangGraph.
    Features enhanced memory management, context enrichment, and tool execution tracking.
    """

    def __init__(
        self,
        embedding_model: HuggingFaceEmbeddings,
        vector_store: VectorStore,
        llm: BaseChatModel,
        tools: Optional[List[BaseTool]] = None,
        max_iterations: int = 10,
        memory_threshold: float = 0.7
    ):
        """
        Initialize the agent with required components.

        Args:
            embedding_model: Model for embedding text
            vector_store: Storage for agent memory
            llm: Language model for agent reasoning
            tools: List of tools accessible to the agent
            max_iterations: Maximum number of tool-calling iterations
            memory_threshold: Threshold for deciding when to include memory context (0-1)
        """
        self.embedding_model = embedding_model
        self.vector_store = vector_store
        self.llm = llm
        self.tools = tools or []
        self.max_iterations = max_iterations
        self.memory_threshold = memory_threshold

        # Setup retriever for memory access
        self.retriever = vector_store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"k": 3, "score_threshold": 0.75}
        )
        # Create memory retrieval tool
        self.memory_tool = create_retriever_tool(
            retriever=self.retriever,
            name="memory_search",
            description="Search the agent's memory for relevant past interactions and knowledge."
        )
        # Add memory tool to the agent's toolset
        self.all_tools = self.tools + [self.memory_tool]
        # Map tool names to tools for execution, and bind the tool schemas to the
        # LLM so it can emit structured tool calls (requires a model with native
        # tool-calling support).
        self.tools_by_name = {tool.name: tool for tool in self.all_tools}
        self.llm_with_tools = self.llm.bind_tools(self.all_tools)
        # Build the agent's execution graph
        self.agent_executor = self._build_agent_graph()
        print(f"AdvancedToolAgent initialized with {len(self.all_tools)} tools")

    def __call__(self, question: str) -> str:
        """
        Process a question using the agent.

        Args:
            question: The user query to respond to

        Returns:
            The agent's response
        """
        question_preview = f"{question[:50]}..." if len(question) > 50 else question
        print(f"Agent received question: {question_preview}")

        # Enrich context with relevant memory
        enriched_input = self._enrich_context(question)

        # Create initial state
        initial_state = {
            "messages": [HumanMessage(content=enriched_input)],
            "tools": self.all_tools,
            "tool_calls": [],
        }

        # Execute agent graph; each iteration is one agent step plus one tool step,
        # so bound the recursion limit accordingly to prevent infinite loops
        final_state = self.agent_executor.invoke(
            initial_state,
            config={"recursion_limit": 2 * self.max_iterations + 1},
        )

        # Extract the final response
        final_message = final_state["messages"][-1]
        answer = final_message.content

        # Store this interaction in memory
        self._store_interaction(question, answer, final_state.get("tool_calls", []))

        # Periodically manage memory
        self._periodic_memory_management()

        answer_preview = f"{answer[:50]}..." if len(answer) > 50 else answer
        print(f"Agent returning answer: {answer_preview}")
        return answer

    def _build_agent_graph(self):
        """Build the LangGraph execution graph with enhanced tool calling"""

        # Function for the agent to process messages and call tools
        def agent_node(state: AgentState) -> AgentState:
            """Process messages and decide on the next action"""
            messages = state["messages"]

            # Add system instructions with tool details
            if not any(isinstance(msg, SystemMessage) for msg in messages):
                system_prompt = self._create_system_prompt()
                messages = [SystemMessage(content=system_prompt)] + messages

            # Get response from the tool-bound LLM
            response = self.llm_with_tools.invoke(messages)

            # Extract any structured tool calls emitted by the model
            tool_calls = list(getattr(response, "tool_calls", []) or [])

            # Update state
            new_state = dict(state)
            new_state["messages"] = messages + [response]
            new_state["tool_calls"] = tool_calls
            return new_state

        # Function for executing tools
        def tool_node(state: AgentState) -> AgentState:
            """Execute tools and add results to messages"""
            # Get the tool calls from the state (dicts with "name", "args", "id")
            tool_calls = state["tool_calls"]

            # Execute each tool call
            tool_results = []
            for tool_call in tool_calls:
                name = tool_call["name"]
                args = tool_call.get("args", {})
                call_id = tool_call.get("id", "")
                try:
                    # Look up and execute the tool
                    tool = self.tools_by_name[name]
                    result = tool.invoke(args)
                    # Create a tool message with the result
                    tool_msg = ToolMessage(
                        content=str(result),
                        tool_call_id=call_id,
                        name=name,
                    )
                    tool_results.append(tool_msg)
                    # Track tool usage for memory
                    self._track_tool_usage(name, args, result)
                except Exception as e:
                    # Handle tool execution errors
                    error_msg = f"Error executing tool {name}: {str(e)}"
                    tool_msg = ToolMessage(
                        content=error_msg,
                        tool_call_id=call_id,
                        name=name,
                    )
                    tool_results.append(tool_msg)

            # Update state with tool results
            new_state = dict(state)
            new_state["messages"] = state["messages"] + tool_results
            new_state["tool_calls"] = []
            return new_state

        # Create the graph
        graph = StateGraph(AgentState)

        # Add nodes
        graph.add_node("agent", agent_node)
        graph.add_node("tools", tool_node)

        # Set the entry point
        graph.set_entry_point("agent")

        # Add edges: run tools if the model requested any, otherwise finish
        graph.add_conditional_edges(
            "agent",
            lambda state: "tools" if state["tool_calls"] else END,
            {
                "tools": "tools",
                END: END,
            }
        )
        graph.add_edge("tools", "agent")

        # The loop is bounded at invoke time via the recursion_limit config
        return graph.compile()

    def _create_system_prompt(self) -> str:
        """Create a system prompt with tool instructions"""
        tool_descriptions = "\n\n".join([
            f"Tool {i+1}: {tool.name}\n"
            f"Description: {tool.description}\n"
            f"Args: {json.dumps(tool.args, indent=2) if hasattr(tool, 'args') else 'No arguments required'}"
            for i, tool in enumerate(self.all_tools)
        ])
        return f"""You are an advanced AI assistant with access to various tools.
When a user asks a question, use your knowledge and the available tools to provide
accurate and helpful responses.

AVAILABLE TOOLS:
{tool_descriptions}

INSTRUCTIONS FOR TOOL USAGE:
1. When you need information that requires a tool, call the appropriate tool.
2. Format tool calls clearly by specifying the tool name and inputs.
3. Wait for tool results before providing final answers.
4. Use tools only when necessary - if you can answer directly, do so.
5. If a tool fails, try a different approach or tool.
6. Always explain your reasoning step by step.

Remember to be helpful, accurate, and concise in your responses.
"""

    def _enrich_context(self, query: str) -> str:
        """Enrich the input query with relevant context from memory"""
        # Search for similar content (limit to the 2 most relevant documents)
        similar_docs = self.vector_store.similarity_search(query, k=2)

        # Fall back to the raw query if nothing relevant is stored yet
        if not similar_docs:
            return query

        # Build enhanced context
        context_additions = []
        for doc in similar_docs:
            content = doc.page_content
            # Extract different types of memory
            if "Question:" in content and "Final answer:" in content:
                # Q&A memory
                q = content.split("Question:")[1].split("Final answer:")[0].strip()
                a = content.split("Final answer:")[1].split("Timestamp:", 1)[0].strip()
                # Only add if it's not too similar to the current question
                if not self._is_similar_question(query, q, threshold=0.85):
                    context_additions.append(f"Related Q: {q}\nRelated A: {a}")
            elif "Tool Knowledge" in content:
                # Tool usage memory
                tool_name = content.split("Tool:")[1].split("Query:")[0].strip()
                tool_result = content.split("Result:")[1].split("Timestamp:")[0].strip()
                context_additions.append(
                    f"From prior tool use ({tool_name}): {tool_result[:200]}"
                )

        # Only add context if we have relevant information
        if context_additions:
            return (
                "Consider this relevant information first:\n\n" +
                "\n\n".join(context_additions[:2]) +  # Limit to 2 pieces of context
                "\n\nNow answering this question: " + query
            )
        return query

    def _is_similar_question(self, query1: str, query2: str, threshold: float = 0.8) -> bool:
        """Check if two questions are semantically similar using embeddings"""
        # Get embeddings for both queries
        if hasattr(self.embedding_model, 'embed_query'):
            emb1 = self.embedding_model.embed_query(query1)
            emb2 = self.embedding_model.embed_query(query2)
            # Calculate cosine similarity
            similarity = self._cosine_similarity(emb1, emb2)
            return similarity > threshold
        return False

    def _cosine_similarity(self, v1, v2) -> float:
        """Calculate cosine similarity between two vectors"""
        dot_product = sum(x * y for x, y in zip(v1, v2))
        magnitude1 = sum(x * x for x in v1) ** 0.5
        magnitude2 = sum(x * x for x in v2) ** 0.5
        if magnitude1 * magnitude2 == 0:
            return 0.0
        return dot_product / (magnitude1 * magnitude2)

    def _store_interaction(self, question: str, answer: str, tool_calls: List[dict]) -> None:
        """Store the interaction in vector memory"""
        timestamp = datetime.now().isoformat()

        # Format tools used
        tools_used = []
        for tool_call in tool_calls:
            if isinstance(tool_call, dict) and 'name' in tool_call:
                tools_used.append(tool_call['name'])
            elif hasattr(tool_call, 'name'):
                tools_used.append(tool_call.name)
        tools_str = ", ".join(tools_used) if tools_used else "None"

        # Create content
        content = (
            f"Question: {question}\n"
            f"Tools Used: {tools_str}\n"
            f"Final answer: {answer}\n"
            f"Timestamp: {timestamp}"
        )

        # Create document with metadata
        doc = Document(
            page_content=content,
            metadata={
                "question": question,
                "timestamp": timestamp,
                "type": "qa_pair",
                "tools_used": tools_str
            }
        )
        # Add to vector store
        self.vector_store.add_documents([doc])

    def _track_tool_usage(self, tool_name: str, tool_input: Any, tool_output: Any) -> None:
        """Track tool usage for future reference"""
        timestamp = datetime.now().isoformat()

        # Format the content
        content = (
            f"Tool Knowledge\n"
            f"Tool: {tool_name}\n"
            f"Query: {str(tool_input)}\n"
            f"Result: {str(tool_output)}\n"
            f"Timestamp: {timestamp}"
        )

        # Create document with metadata
        doc = Document(
            page_content=content,
            metadata={
                "type": "tool_knowledge",
                "tool": tool_name,
                "timestamp": timestamp
            }
        )
        # Add to vector store
        self.vector_store.add_documents([doc])

    def _periodic_memory_management(self,
                                    check_frequency: int = 10,
                                    max_documents: int = 1000,
                                    max_age_days: int = 30) -> None:
        """Periodically manage memory to prevent unbounded growth"""
        # Simple probabilistic check to avoid running this too often
        if hash(datetime.now().isoformat()) % check_frequency != 0:
            return
        self.manage_memory(max_documents, max_age_days)

    def manage_memory(self, max_documents: int = 1000, max_age_days: int = 30) -> None:
        """
        Manage memory by pruning old or less useful entries from the vector store.

        Args:
            max_documents: Maximum number of documents to keep
            max_age_days: Remove documents older than this many days
        """
        print("Starting memory management...")

        # Get all documents from the vector store
        try:
            # For vector stores that have a get_all_documents method
            if hasattr(self.vector_store, "get_all_documents"):
                all_docs = self.vector_store.get_all_documents()
                all_ids = [doc.metadata.get("id", i) for i, doc in enumerate(all_docs)]
            # For other vector store implementations
            else:
                print("Warning: Vector store doesn't expose required attributes for memory management")
                return
        except Exception as e:
            print(f"Error accessing vector store documents: {e}")
            return

        if not all_docs:
            print("No documents found in vector store")
            return

        print(f"Retrieved {len(all_docs)} documents for scoring")

        # Score each document based on recency, importance and relevance
        scored_docs = []
        cutoff_date = datetime.now() - timedelta(days=max_age_days)

        for i, doc in enumerate(all_docs):
            doc_id = all_ids[i] if i < len(all_ids) else i

            # Extract timestamp from content or metadata
            timestamp = None
            if hasattr(doc, "metadata") and doc.metadata and "timestamp" in doc.metadata:
                try:
                    timestamp = datetime.fromisoformat(doc.metadata["timestamp"])
                except (ValueError, TypeError):
                    pass

            # If no timestamp in metadata, try to extract from content
            if not timestamp and hasattr(doc, "page_content") and "Timestamp:" in doc.page_content:
                try:
                    timestamp_str = doc.page_content.split("Timestamp:")[-1].strip().split('\n')[0]
                    timestamp = datetime.fromisoformat(timestamp_str)
                except (ValueError, TypeError):
                    timestamp = datetime.now() - timedelta(days=max_age_days + 1)

            # If still no timestamp, use a default
            if not timestamp:
                timestamp = datetime.now() - timedelta(days=max_age_days + 1)

            # Calculate age score (newer is better)
            age_factor = max(0.0, min(1.0, (timestamp - cutoff_date).total_seconds() /
                                      (datetime.now() - cutoff_date).total_seconds()))

            # Calculate importance score based on document type and access frequency
            importance_factor = 1.0

            # Tool knowledge is more valuable
            if hasattr(doc, "metadata") and doc.metadata and doc.metadata.get("type") == "tool_knowledge":
                importance_factor += 0.5

            # If document has been accessed often, increase importance
            if hasattr(doc, "metadata") and doc.metadata and "access_count" in doc.metadata:
                importance_factor += min(1.0, doc.metadata["access_count"] / 10)

            # If document contains references to complex tools, prioritize it
            if hasattr(doc, "page_content"):
                complex_tools = ["web_search", "python_repl", "analyze_image", "arxiv_search"]
                if any(tool in doc.page_content for tool in complex_tools):
                    importance_factor += 0.3

            # Create combined score (higher = more valuable to keep)
            total_score = (0.6 * age_factor) + (0.4 * importance_factor)

            # Add to priority queue (negative for max-heap behavior)
            heapq.heappush(scored_docs, (-total_score, i, doc))

        # Select top documents to keep
        docs_to_keep = []
        for _ in range(min(max_documents, len(scored_docs))):
            if scored_docs:
                _, _, doc = heapq.heappop(scored_docs)
                docs_to_keep.append(doc)

        # Only rebuild if we're actually pruning some documents
        if len(docs_to_keep) < len(all_docs):
            print(f"Memory management: Keeping {len(docs_to_keep)} documents out of {len(all_docs)}")

            # Create a new vector store with the same type as the current one
            vector_store_type = type(self.vector_store)

            # Different approaches based on vector store type
            if hasattr(vector_store_type, "from_documents"):
                # Most langchain vector stores support this method
                new_vector_store = vector_store_type.from_documents(
                    docs_to_keep,
                    embedding=self.embedding_model
                )
                self.vector_store = new_vector_store
                print(f"Vector store rebuilt with {len(docs_to_keep)} documents")
            elif hasattr(vector_store_type, "from_texts"):
                # For vector stores that use from_texts
                texts = [doc.page_content for doc in docs_to_keep]
                metadatas = [doc.metadata if hasattr(doc, "metadata") else {} for doc in docs_to_keep]
                new_vector_store = vector_store_type.from_texts(
                    texts=texts,
                    embedding=self.embedding_model,
                    metadatas=metadatas
                )
                self.vector_store = new_vector_store
                print(f"Vector store rebuilt with {len(docs_to_keep)} documents")
            else:
                print("Warning: Could not determine how to rebuild the vector store")
                print(f"Vector store type: {vector_store_type.__name__}")


# Example usage
if __name__ == "__main__":
    from langchain_chroma import Chroma
    from langchain_groq import ChatGroq
    from basic_tools import multiply, add, subtract, divide, wiki_search, web_search
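    # basic_tools is a local module that is not shown here; the imports above are
    # assumed to be ordinary @tool-decorated callables, roughly like this
    # illustrative sketch (names and signatures are assumptions, not the actual
    # module contents):
    #
    #     from langchain_core.tools import tool
    #
    #     @tool
    #     def multiply(a: float, b: float) -> float:
    #         """Multiply two numbers and return the product."""
    #         return a * b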

    # Initialize embeddings
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
        model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"}
    )

    # Initialize vector store
    vector_store = Chroma(
        embedding_function=embeddings,
        collection_name="advanced_agent_memory"
    )

    # Initialize LLM
    llm = ChatGroq(model="qwen-qwq-32b", temperature=0)

    # Define tools
    tools = [multiply, add, subtract, divide, wiki_search, web_search]

    # Create agent
    agent = AdvancedToolAgent(
        embedding_model=embeddings,
        vector_store=vector_store,
        llm=llm,
        tools=tools
    )

    # Test the agent
    response = agent("What is the population of France multiplied by 2?")
    print(f"Response: {response}")