import heapq
import json
from datetime import datetime, timedelta
from typing import Any, List, Optional, TypedDict

import torch
from langchain_core.documents import Document
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.tools import BaseTool
from langchain_core.vectorstores import VectorStore
from langchain.tools.retriever import create_retriever_tool
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.graph import END, StateGraph


class AgentState(TypedDict):
    """State carried through the agent graph."""

    messages: List[BaseMessage]
    tool_calls: List[dict]
    tool_history: List[dict]


class AdvancedToolAgent:
    """
    An advanced agent with robust tool-calling capabilities using LangGraph.
    Features enhanced memory management, context enrichment, and tool
    execution tracking.
    """

    def __init__(
        self,
        embedding_model: HuggingFaceEmbeddings,
        vector_store: VectorStore,
        llm: BaseChatModel,
        tools: Optional[List[BaseTool]] = None,
        max_iterations: int = 10,
        memory_threshold: float = 0.7,
    ):
        """
        Initialize the agent with required components.

        Args:
            embedding_model: Model for embedding text
            vector_store: Storage for agent memory
            llm: Language model for agent reasoning
            tools: List of tools accessible to the agent
            max_iterations: Maximum number of tool-calling iterations
            memory_threshold: Threshold for deciding when to include memory context (0-1)
        """
        self.embedding_model = embedding_model
        self.vector_store = vector_store
        self.llm = llm
        self.tools = tools or []
        self.max_iterations = max_iterations
        self.memory_threshold = memory_threshold

        # Setup retriever for memory access
        self.retriever = vector_store.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={"k": 3, "score_threshold": 0.75},
        )

        # Create memory retrieval tool
        self.memory_tool = create_retriever_tool(
            retriever=self.retriever,
            name="memory_search",
            description="Search the agent's memory for relevant past interactions and knowledge.",
        )

        # Add memory tool to the agent's toolset
        self.all_tools = self.tools + [self.memory_tool]
        self.tools_by_name = {tool.name: tool for tool in self.all_tools}

        # Bind tools to the LLM so it can emit structured tool calls
        self.llm_with_tools = self.llm.bind_tools(self.all_tools)

        # Build the agent's execution graph
        self.agent_executor = self._build_agent_graph()

        print(f"AdvancedToolAgent initialized with {len(self.all_tools)} tools")

    def __call__(self, question: str) -> str:
        """
        Process a question using the agent.

        Args:
            question: The user query to respond to

        Returns:
            The agent's response
        """
        preview = f"{question[:50]}..." if len(question) > 50 else question
        print(f"Agent received question: {preview}")

        # Enrich context with relevant memory
        enriched_input = self._enrich_context(question)

        # Create initial state
        initial_state = {
            "messages": [HumanMessage(content=enriched_input)],
            "tool_calls": [],
            "tool_history": [],
        }

        # Execute agent graph; the recursion limit bounds agent/tool round trips
        final_state = self.agent_executor.invoke(
            initial_state,
            config={"recursion_limit": 2 * self.max_iterations},
        )

        # Extract the final response
        final_message = final_state["messages"][-1]
        answer = final_message.content

        # Store this interaction in memory, including which tools were used
        self._store_interaction(question, answer, final_state.get("tool_history", []))

        # Periodically manage memory
        self._periodic_memory_management()

        preview = f"{answer[:50]}..." if len(answer) > 50 else answer
        print(f"Agent returning answer: {preview}")
        return answer
    def _build_agent_graph(self):
        """Build the LangGraph execution graph with enhanced tool calling"""

        # Function for the agent to process messages and call tools
        def agent_node(state: AgentState) -> AgentState:
            """Process messages and decide on the next action"""
            messages = state["messages"]

            # Add system instructions with tool details
            if not any(isinstance(msg, SystemMessage) for msg in messages):
                system_prompt = self._create_system_prompt()
                messages = [SystemMessage(content=system_prompt)] + messages

            # Get response from the tool-calling LLM
            response = self.llm_with_tools.invoke(messages)

            # Extract any tool calls from the response
            tool_calls = list(getattr(response, "tool_calls", []) or [])

            # Update state
            new_state = state.copy()
            new_state["messages"] = messages + [response]
            new_state["tool_calls"] = tool_calls
            return new_state

        # Function for executing tools
        def tool_node(state: AgentState) -> AgentState:
            """Execute tools and add results to messages"""
            # Get the tool calls from the state
            tool_calls = state["tool_calls"]

            # Execute each tool call
            tool_results = []
            for tool_call in tool_calls:
                tool_name = tool_call["name"]
                tool_args = tool_call["args"]
                tool_call_id = tool_call["id"]
                try:
                    # Look up and execute the tool
                    result = self.tools_by_name[tool_name].invoke(tool_args)

                    # Create a tool message with the result
                    tool_msg = ToolMessage(
                        content=str(result),
                        tool_call_id=tool_call_id,
                        name=tool_name,
                    )
                    tool_results.append(tool_msg)

                    # Track tool usage for memory
                    self._track_tool_usage(tool_name, tool_args, result)
                except Exception as e:
                    # Handle tool execution errors
                    error_msg = f"Error executing tool {tool_name}: {str(e)}"
                    tool_msg = ToolMessage(
                        content=error_msg,
                        tool_call_id=tool_call_id,
                        name=tool_name,
                    )
                    tool_results.append(tool_msg)

            # Update state with tool results; tool_calls is cleared each round,
            # so executed calls are accumulated separately in tool_history
            new_state = state.copy()
            new_state["messages"] = state["messages"] + tool_results
            new_state["tool_history"] = state.get("tool_history", []) + list(tool_calls)
            new_state["tool_calls"] = []
            return new_state

        # Create the graph
        graph = StateGraph(AgentState)

        # Add nodes
        graph.add_node("agent", agent_node)
        graph.add_node("tools", tool_node)

        # Set the entry point
        graph.set_entry_point("agent")

        # Add edges
        graph.add_conditional_edges(
            "agent",
            lambda state: "tools" if state["tool_calls"] else END,
            {
                "tools": "tools",
                END: END,
            },
        )
        graph.add_edge("tools", "agent")

        # The iteration cap is enforced via the recursion_limit passed at invoke time
        return graph.compile()

    def _create_system_prompt(self) -> str:
        """Create a system prompt with tool instructions"""
        tool_descriptions = "\n\n".join([
            f"Tool {i + 1}: {tool.name}\n"
            f"Description: {tool.description}\n"
            f"Args: {json.dumps(tool.args, indent=2) if hasattr(tool, 'args') else 'No arguments required'}"
            for i, tool in enumerate(self.all_tools)
        ])

        return f"""You are an advanced AI assistant with access to various tools.
When a user asks a question, use your knowledge and the available tools to provide accurate and helpful responses.

AVAILABLE TOOLS:
{tool_descriptions}

INSTRUCTIONS FOR TOOL USAGE:
1. When you need information that requires a tool, call the appropriate tool.
2. Format tool calls clearly by specifying the tool name and inputs.
3. Wait for tool results before providing final answers.
4. Use tools only when necessary - if you can answer directly, do so.
5. If a tool fails, try a different approach or tool.
6. Always explain your reasoning step by step.

Remember to be helpful, accurate, and concise in your responses.
""" def _enrich_context(self, query: str) -> str: """Enrich the input query with relevant context from memory""" # Search for similar content similar_docs = self.vector_store.similarity_search( query, k=2, # Limit to 2 most relevant documents fetch_k=5 # Consider 5 candidates ) # Only use memory if relevance is high enough if not similar_docs or len(similar_docs) == 0: return query # Build enhanced context context_additions = [] for doc in similar_docs: content = doc.page_content # Extract different types of memory if "Question:" in content and "Final answer:" in content: # Q&A memory q = content.split("Question:")[1].split("Final answer:")[0].strip() a = content.split("Final answer:")[1].split("Timestamp:", 1)[0].strip() # Only add if it's not too similar to current question if not self._is_similar_question(query, q, threshold=0.85): context_additions.append(f"Related Q: {q}\nRelated A: {a}") elif "Tool Knowledge" in content: # Tool usage memory tool_name = content.split("Tool:")[1].split("Query:")[0].strip() tool_result = content.split("Result:")[1].split("Timestamp:")[0].strip() context_additions.append( f"From prior tool use ({tool_name}): {tool_result[:200]}" ) # Only add context if we have relevant information if context_additions: return ( "Consider this relevant information first:\n\n" + "\n\n".join(context_additions[:2]) + # Limit to 2 pieces of context "\n\nNow answering this question: " + query ) else: return query def _is_similar_question(self, query1: str, query2: str, threshold: float = 0.8) -> bool: """Check if two questions are semantically similar using embeddings""" # Get embeddings for both queries if hasattr(self.embedding_model, 'embed_query'): emb1 = self.embedding_model.embed_query(query1) emb2 = self.embedding_model.embed_query(query2) # Calculate cosine similarity similarity = self._cosine_similarity(emb1, emb2) return similarity > threshold return False @staticmethod def _cosine_similarity(v1, v2): """Calculate cosine similarity between vectors""" dot_product = sum(x * y for x, y in zip(v1, v2)) magnitude1 = sum(x * x for x in v1) ** 0.5 magnitude2 = sum(x * x for x in v2) ** 0.5 if magnitude1 * magnitude2 == 0: return 0 return dot_product / (magnitude1 * magnitude2) def _store_interaction(self, question: str, answer: str, tool_calls: List[dict]) -> None: """Store the interaction in vector memory""" timestamp = datetime.now().isoformat() # Format tools used tools_used = [] for tool_call in tool_calls: if isinstance(tool_call, dict) and 'name' in tool_call: tools_used.append(tool_call['name']) elif hasattr(tool_call, 'name'): tools_used.append(tool_call.name) tools_str = ", ".join(tools_used) if tools_used else "None" # Create content content = ( f"Question: {question}\n" f"Tools Used: {tools_str}\n" f"Final answer: {answer}\n" f"Timestamp: {timestamp}" ) # Create document with metadata doc = Document( page_content=content, metadata={ "question": question, "timestamp": timestamp, "type": "qa_pair", "tools_used": tools_str } ) # Add to vector store self.vector_store.add_documents([doc]) def _track_tool_usage(self, tool_name: str, tool_input: Any, tool_output: Any) -> None: """Track tool usage for future reference""" timestamp = datetime.now().isoformat() # Format the content content = ( f"Tool Knowledge\n" f"Tool: {tool_name}\n" f"Query: {str(tool_input)}\n" f"Result: {str(tool_output)}\n" f"Timestamp: {timestamp}" ) # Create document with metadata doc = Document( page_content=content, metadata={ "type": "tool_knowledge", "tool": tool_name, "timestamp": timestamp } 
    def _periodic_memory_management(
        self,
        check_frequency: int = 10,
        max_documents: int = 1000,
        max_age_days: int = 30,
    ) -> None:
        """Periodically manage memory to prevent unbounded growth"""
        # Simple probabilistic check to avoid running this too often
        if hash(datetime.now().isoformat()) % check_frequency != 0:
            return
        self.manage_memory(max_documents, max_age_days)

    def manage_memory(self, max_documents: int = 1000, max_age_days: int = 30) -> None:
        """
        Manage memory by pruning old or less useful entries from the vector store.

        Args:
            max_documents: Maximum number of documents to keep
            max_age_days: Remove documents older than this many days
        """
        print("Starting memory management...")

        # Get all documents from the vector store
        try:
            # For vector stores that expose a get_all_documents method
            if hasattr(self.vector_store, "get_all_documents"):
                all_docs = self.vector_store.get_all_documents()
                all_ids = [doc.metadata.get("id", i) for i, doc in enumerate(all_docs)]
            # For other vector store implementations
            else:
                print("Warning: Vector store doesn't expose required attributes for memory management")
                return
        except Exception as e:
            print(f"Error accessing vector store documents: {e}")
            return

        if not all_docs:
            print("No documents found in vector store")
            return

        print(f"Retrieved {len(all_docs)} documents for scoring")

        # Score each document based on recency, importance and relevance
        scored_docs = []
        cutoff_date = datetime.now() - timedelta(days=max_age_days)

        for i, doc in enumerate(all_docs):
            doc_id = all_ids[i] if i < len(all_ids) else i

            # Extract timestamp from content or metadata
            timestamp = None
            if hasattr(doc, "metadata") and doc.metadata and "timestamp" in doc.metadata:
                try:
                    timestamp = datetime.fromisoformat(doc.metadata["timestamp"])
                except (ValueError, TypeError):
                    pass

            # If no timestamp in metadata, try to extract it from the content
            if not timestamp and hasattr(doc, "page_content") and "Timestamp:" in doc.page_content:
                try:
                    timestamp_str = doc.page_content.split("Timestamp:")[-1].strip().split("\n")[0]
                    timestamp = datetime.fromisoformat(timestamp_str)
                except (ValueError, TypeError):
                    timestamp = datetime.now() - timedelta(days=max_age_days + 1)

            # If still no timestamp, use a default that marks the document as stale
            if not timestamp:
                timestamp = datetime.now() - timedelta(days=max_age_days + 1)

            # Calculate age score (newer is better)
            age_factor = max(
                0.0,
                min(
                    1.0,
                    (timestamp - cutoff_date).total_seconds()
                    / (datetime.now() - cutoff_date).total_seconds(),
                ),
            )

            # Calculate importance score based on document type and access frequency
            importance_factor = 1.0

            # Tool knowledge is more valuable
            if hasattr(doc, "metadata") and doc.metadata and doc.metadata.get("type") == "tool_knowledge":
                importance_factor += 0.5

            # If the document has been accessed often, increase its importance
            if hasattr(doc, "metadata") and doc.metadata and "access_count" in doc.metadata:
                importance_factor += min(1.0, doc.metadata["access_count"] / 10)

            # If the document references complex tools, prioritize it
            if hasattr(doc, "page_content"):
                complex_tools = ["web_search", "python_repl", "analyze_image", "arxiv_search"]
                if any(tool in doc.page_content for tool in complex_tools):
                    importance_factor += 0.3

            # Create combined score (higher = more valuable to keep)
            total_score = (0.6 * age_factor) + (0.4 * importance_factor)

            # Add to priority queue (negative for max-heap behavior)
            heapq.heappush(scored_docs, (-total_score, i, doc))

        # Select the top documents to keep
        docs_to_keep = []
        for _ in range(min(max_documents, len(scored_docs))):
            if scored_docs:
                _, _, doc = heapq.heappop(scored_docs)
                docs_to_keep.append(doc)

        # Only rebuild if we're actually pruning some documents
        if len(docs_to_keep) < len(all_docs):
            print(f"Memory management: Keeping {len(docs_to_keep)} documents out of {len(all_docs)}")

            # Create a new vector store of the same type as the current one
            vector_store_type = type(self.vector_store)

            # Different approaches based on vector store type
            if hasattr(vector_store_type, "from_documents"):
                # Most LangChain vector stores support this classmethod
                new_vector_store = vector_store_type.from_documents(
                    docs_to_keep,
                    embedding=self.embedding_model,
                )
                self.vector_store = new_vector_store
                print(f"Vector store rebuilt with {len(docs_to_keep)} documents")
            elif hasattr(vector_store_type, "from_texts"):
                # For vector stores that use from_texts
                texts = [doc.page_content for doc in docs_to_keep]
                metadatas = [doc.metadata if hasattr(doc, "metadata") else {} for doc in docs_to_keep]
                new_vector_store = vector_store_type.from_texts(
                    texts=texts,
                    embedding=self.embedding_model,
                    metadatas=metadatas,
                )
                self.vector_store = new_vector_store
                print(f"Vector store rebuilt with {len(docs_to_keep)} documents")
            else:
                print("Warning: Could not determine how to rebuild the vector store")
                print(f"Vector store type: {vector_store_type.__name__}")
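# Illustrative sketch (not part of the original toolset): any extra capability can be
# exposed to AdvancedToolAgent as a standard LangChain tool via the @tool decorator.
# The word_count tool below is a hypothetical example of that pattern; the tools used
# in the demo below come from the local basic_tools module instead.
from langchain_core.tools import tool


@tool
def word_count(text: str) -> int:
    """Count the number of words in a piece of text."""
    return len(text.split())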
# Example usage
if __name__ == "__main__":
    from langchain_chroma import Chroma
    from langchain_groq import ChatGroq

    from basic_tools import multiply, add, subtract, divide, wiki_search, web_search

    # Initialize embeddings
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
        model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
    )

    # Initialize vector store
    vector_store = Chroma(
        embedding_function=embeddings,
        collection_name="advanced_agent_memory",
    )

    # Initialize LLM
    llm = ChatGroq(model="qwen-qwq-32b", temperature=0)

    # Define tools
    tools = [multiply, add, subtract, divide, wiki_search, web_search]

    # Create agent
    agent = AdvancedToolAgent(
        embedding_model=embeddings,
        vector_store=vector_store,
        llm=llm,
        tools=tools,
    )

    # Test the agent
    response = agent("What is the population of France multiplied by 2?")
    print(f"Response: {response}")
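    # Illustrative follow-up (assumes the first call above succeeded): the Q&A pair just
    # stored by _store_interaction now lives in the vector store, so _enrich_context should
    # surface it as "Related Q/A" context when this second, related question is asked.
    follow_up = agent("What is the population of France multiplied by 2, then divided by 4?")
    print(f"Follow-up response: {follow_up}")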