Sheshank Joshi committed · Commit 9fced79
Parent(s): 554ef85

reasoning agent
Files changed:

- __pycache__/agent.cpython-312.pyc +0 -0
- __pycache__/basic_tools.cpython-312.pyc +0 -0
- __pycache__/reasoning_agent.cpython-312.pyc +0 -0
- __pycache__/utils.cpython-312.pyc +0 -0
- advanced_tool_agent.py +546 -0
- agent.py +25 -31
- app.py +23 -6
- basic_tools.py +618 -9
- chain_of_thought.py +34 -0
- react_agent.py +57 -0
- reasoning_agent.py +340 -0
- system_prompt.txt +37 -11
- tool_calling_agent.py +16 -0
- utils.py +57 -0
__pycache__/agent.cpython-312.pyc
CHANGED
Binary files a/__pycache__/agent.cpython-312.pyc and b/__pycache__/agent.cpython-312.pyc differ

__pycache__/basic_tools.cpython-312.pyc
CHANGED
Binary files a/__pycache__/basic_tools.cpython-312.pyc and b/__pycache__/basic_tools.cpython-312.pyc differ

__pycache__/reasoning_agent.cpython-312.pyc
ADDED
Binary file (10.6 kB)

__pycache__/utils.cpython-312.pyc
ADDED
Binary file (1.88 kB)
advanced_tool_agent.py
ADDED
@@ -0,0 +1,546 @@
import os
from typing import List, Dict, Any, Optional, Type, Callable
from datetime import datetime, timedelta
import heapq
import json

import torch
from langchain_core.tools import BaseTool
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.vectorstores import VectorStore
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain.tools.retriever import create_retriever_tool
from langchain_huggingface import HuggingFaceEmbeddings

from langgraph.graph import StateGraph, END
from langgraph.prebuilt import (
    ToolNode,
    ToolInvocation,
    agent_executor,
    create_function_calling_executor,
    AgentState,
    MessageGraph
)
from langgraph.prebuilt.tool_executor import ToolExecutor, extract_tool_invocations
from langgraph.prebuilt.tool_nodes import get_default_tool_node_parser


class AdvancedToolAgent:
    """
    An advanced agent with robust tool-calling capabilities using LangGraph.
    Features enhanced memory management, context enrichment, and tool execution tracking.
    """

    def __init__(
        self,
        embedding_model: HuggingFaceEmbeddings,
        vector_store: VectorStore,
        llm: BaseChatModel,
        tools: Optional[List[BaseTool]] = None,
        max_iterations: int = 10,
        memory_threshold: float = 0.7
    ):
        """
        Initialize the agent with required components.

        Args:
            embedding_model: Model for embedding text
            vector_store: Storage for agent memory
            llm: Language model for agent reasoning
            tools: List of tools accessible to the agent
            max_iterations: Maximum number of tool calling iterations
            memory_threshold: Threshold for deciding when to include memory context (0-1)
        """
        self.embedding_model = embedding_model
        self.vector_store = vector_store
        self.llm = llm
        self.tools = tools or []
        self.max_iterations = max_iterations
        self.memory_threshold = memory_threshold

        # Setup retriever for memory access
        self.retriever = vector_store.as_retriever(
            search_kwargs={"k": 3, "score_threshold": 0.75}
        )

        # Create memory retrieval tool
        self.memory_tool = create_retriever_tool(
            retriever=self.retriever,
            name="memory_search",
            description="Search the agent's memory for relevant past interactions and knowledge."
        )

        # Add memory tool to the agent's toolset
        self.all_tools = self.tools + [self.memory_tool]

        # Setup tool executor
        self.tool_executor = ToolExecutor(self.all_tools)

        # Build the agent's execution graph
        self.agent_executor = self._build_agent_graph()

        print(f"AdvancedToolAgent initialized with {len(self.all_tools)} tools")

    def __call__(self, question: str) -> str:
        """
        Process a question using the agent.

        Args:
            question: The user query to respond to

        Returns:
            The agent's response
        """
        print(f"Agent received question: {question[:50]}..." if len(question) > 50 else question)

        # Enrich context with relevant memory
        enriched_input = self._enrich_context(question)

        # Create initial state
        initial_state = {
            "messages": [HumanMessage(content=enriched_input)],
            "tools": self.all_tools,
            "tool_calls": [],
        }

        # Execute agent graph
        final_state = self.agent_executor.invoke(initial_state)

        # Extract the final response
        final_message = final_state["messages"][-1]
        answer = final_message.content

        # Store this interaction in memory
        self._store_interaction(question, answer, final_state.get("tool_calls", []))

        # Periodically manage memory
        self._periodic_memory_management()

        print(f"Agent returning answer: {answer[:50]}..." if len(answer) > 50 else answer)
        return answer

    def _build_agent_graph(self):
        """Build the LangGraph execution graph with enhanced tool calling"""

        # Function for the agent to process messages and call tools
        def agent_node(state: AgentState) -> AgentState:
            """Process messages and decide on next action"""
            messages = state["messages"]

            # Add system instructions with tool details
            if not any(isinstance(msg, SystemMessage) for msg in messages):
                system_prompt = self._create_system_prompt()
                messages = [SystemMessage(content=system_prompt)] + messages

            # Get response from LLM
            response = self.llm.invoke(messages)

            # Extract any tool calls
            tool_calls = extract_tool_invocations(
                response,
                self.all_tools,
                strict_mode=False,
            )

            # Update state
            new_state = state.copy()
            new_state["messages"] = messages + [response]
            new_state["tool_calls"] = tool_calls

            return new_state

        # Function for executing tools
        def tool_node(state: AgentState) -> AgentState:
            """Execute tools and add results to messages"""
            # Get the tool calls from the state
            tool_calls = state["tool_calls"]

            # Execute each tool call
            tool_results = []
            for tool_call in tool_calls:
                try:
                    # Execute the tool
                    result = self.tool_executor.invoke(tool_call)

                    # Create a tool message with the result
                    tool_msg = ToolMessage(
                        content=str(result),
                        tool_call_id=tool_call.id,
                        name=tool_call.name,
                    )
                    tool_results.append(tool_msg)

                    # Track tool usage for memory
                    self._track_tool_usage(tool_call.name, tool_call.args, result)
                except Exception as e:
                    # Handle tool execution errors
                    error_msg = f"Error executing tool {tool_call.name}: {str(e)}"
                    tool_msg = ToolMessage(
                        content=error_msg,
                        tool_call_id=tool_call.id,
                        name=tool_call.name,
                    )
                    tool_results.append(tool_msg)

            # Update state with tool results
            new_state = state.copy()
            new_state["messages"] = state["messages"] + tool_results
            new_state["tool_calls"] = []
            return new_state

        # Create the graph
        graph = StateGraph(AgentState)

        # Add nodes
        graph.add_node("agent", agent_node)
        graph.add_node("tools", tool_node)

        # Set the entry point
        graph.set_entry_point("agent")

        # Add edges
        graph.add_conditional_edges(
            "agent",
            lambda state: "tools" if state["tool_calls"] else END,
            {
                "tools": "tools",
                END: END,
            }
        )
        graph.add_edge("tools", "agent")

        # Set max iterations to prevent infinite loops
        return graph.compile(max_iterations=self.max_iterations)

    def _create_system_prompt(self) -> str:
        """Create a system prompt with tool instructions"""
        tool_descriptions = "\n\n".join([
            f"Tool {i+1}: {tool.name}\n"
            f"Description: {tool.description}\n"
            f"Args: {json.dumps(tool.args, indent=2) if hasattr(tool, 'args') else 'No arguments required'}"
            for i, tool in enumerate(self.all_tools)
        ])

        return f"""You are an advanced AI assistant with access to various tools.
When a user asks a question, use your knowledge and the available tools to provide
accurate and helpful responses.

AVAILABLE TOOLS:
{tool_descriptions}

INSTRUCTIONS FOR TOOL USAGE:
1. When you need information that requires a tool, call the appropriate tool.
2. Format tool calls clearly by specifying the tool name and inputs.
3. Wait for tool results before providing final answers.
4. Use tools only when necessary - if you can answer directly, do so.
5. If a tool fails, try a different approach or tool.
6. Always explain your reasoning step by step.

Remember to be helpful, accurate, and concise in your responses.
"""

    def _enrich_context(self, query: str) -> str:
        """Enrich the input query with relevant context from memory"""
        # Search for similar content
        similar_docs = self.vector_store.similarity_search(
            query,
            k=2,  # Limit to 2 most relevant documents
            fetch_k=5  # Consider 5 candidates
        )

        # Only use memory if relevance is high enough
        if not similar_docs or len(similar_docs) == 0:
            return query

        # Build enhanced context
        context_additions = []
        for doc in similar_docs:
            content = doc.page_content

            # Extract different types of memory
            if "Question:" in content and "Final answer:" in content:
                # Q&A memory
                q = content.split("Question:")[1].split("Final answer:")[0].strip()
                a = content.split("Final answer:")[1].split("Timestamp:", 1)[0].strip()

                # Only add if it's not too similar to current question
                if not self._is_similar_question(query, q, threshold=0.85):
                    context_additions.append(f"Related Q: {q}\nRelated A: {a}")

            elif "Tool Knowledge" in content:
                # Tool usage memory
                tool_name = content.split("Tool:")[1].split("Query:")[0].strip()
                tool_result = content.split("Result:")[1].split("Timestamp:")[0].strip()
                context_additions.append(
                    f"From prior tool use ({tool_name}): {tool_result[:200]}"
                )

        # Only add context if we have relevant information
        if context_additions:
            return (
                "Consider this relevant information first:\n\n" +
                "\n\n".join(context_additions[:2]) +  # Limit to 2 pieces of context
                "\n\nNow answering this question: " + query
            )
        else:
            return query

    def _is_similar_question(self, query1: str, query2: str, threshold: float = 0.8) -> bool:
        """Check if two questions are semantically similar using embeddings"""
        # Get embeddings for both queries
        if hasattr(self.embedding_model, 'embed_query'):
            emb1 = self.embedding_model.embed_query(query1)
            emb2 = self.embedding_model.embed_query(query2)

            # Calculate cosine similarity
            similarity = self._cosine_similarity(emb1, emb2)
            return similarity > threshold
        return False

    @staticmethod
    def _cosine_similarity(v1, v2):
        """Calculate cosine similarity between vectors"""
        dot_product = sum(x * y for x, y in zip(v1, v2))
        magnitude1 = sum(x * x for x in v1) ** 0.5
        magnitude2 = sum(x * x for x in v2) ** 0.5
        if magnitude1 * magnitude2 == 0:
            return 0
        return dot_product / (magnitude1 * magnitude2)

    def _store_interaction(self, question: str, answer: str, tool_calls: List[dict]) -> None:
        """Store the interaction in vector memory"""
        timestamp = datetime.now().isoformat()

        # Format tools used
        tools_used = []
        for tool_call in tool_calls:
            if isinstance(tool_call, dict) and 'name' in tool_call:
                tools_used.append(tool_call['name'])
            elif hasattr(tool_call, 'name'):
                tools_used.append(tool_call.name)

        tools_str = ", ".join(tools_used) if tools_used else "None"

        # Create content
        content = (
            f"Question: {question}\n"
            f"Tools Used: {tools_str}\n"
            f"Final answer: {answer}\n"
            f"Timestamp: {timestamp}"
        )

        # Create document with metadata
        doc = Document(
            page_content=content,
            metadata={
                "question": question,
                "timestamp": timestamp,
                "type": "qa_pair",
                "tools_used": tools_str
            }
        )

        # Add to vector store
        self.vector_store.add_documents([doc])

    def _track_tool_usage(self, tool_name: str, tool_input: Any, tool_output: Any) -> None:
        """Track tool usage for future reference"""
        timestamp = datetime.now().isoformat()

        # Format the content
        content = (
            f"Tool Knowledge\n"
            f"Tool: {tool_name}\n"
            f"Query: {str(tool_input)}\n"
            f"Result: {str(tool_output)}\n"
            f"Timestamp: {timestamp}"
        )

        # Create document with metadata
        doc = Document(
            page_content=content,
            metadata={
                "type": "tool_knowledge",
                "tool": tool_name,
                "timestamp": timestamp
            }
        )

        # Add to vector store
        self.vector_store.add_documents([doc])

    def _periodic_memory_management(self,
                                    check_frequency: int = 10,
                                    max_documents: int = 1000,
                                    max_age_days: int = 30) -> None:
        """Periodically manage memory to prevent unbounded growth"""
        # Simple probabilistic check to avoid running this too often
        if hash(datetime.now().isoformat()) % check_frequency != 0:
            return

        self.manage_memory(max_documents, max_age_days)

    def manage_memory(self, max_documents: int = 1000, max_age_days: int = 30) -> None:
        """
        Manage memory by pruning old or less useful entries from the vector store.

        Args:
            max_documents: Maximum number of documents to keep
            max_age_days: Remove documents older than this many days
        """
        print("Starting memory management...")

        # Get all documents from the vector store
        try:
            # For vector stores that have a get_all_documents method
            if hasattr(self.vector_store, "get_all_documents"):
                all_docs = self.vector_store.get_all_documents()
                all_ids = [doc.metadata.get("id", i) for i, doc in enumerate(all_docs)]
            # For other vector store implementations
            else:
                print("Warning: Vector store doesn't expose required attributes for memory management")
                return
        except Exception as e:
            print(f"Error accessing vector store documents: {e}")
            return

        if not all_docs:
            print("No documents found in vector store")
            return

        print(f"Retrieved {len(all_docs)} documents for scoring")

        # Score each document based on recency, importance and relevance
        scored_docs = []
        cutoff_date = datetime.now() - timedelta(days=max_age_days)

        for i, doc in enumerate(all_docs):
            doc_id = all_ids[i] if i < len(all_ids) else i

            # Extract timestamp from content or metadata
            timestamp = None
            if hasattr(doc, "metadata") and doc.metadata and "timestamp" in doc.metadata:
                try:
                    timestamp = datetime.fromisoformat(doc.metadata["timestamp"])
                except (ValueError, TypeError):
                    pass

            # If no timestamp in metadata, try to extract from content
            if not timestamp and hasattr(doc, "page_content") and "Timestamp:" in doc.page_content:
                try:
                    timestamp_str = doc.page_content.split("Timestamp:")[-1].strip().split('\n')[0]
                    timestamp = datetime.fromisoformat(timestamp_str)
                except (ValueError, TypeError):
                    timestamp = datetime.now() - timedelta(days=max_age_days+1)

            # If still no timestamp, use a default
            if not timestamp:
                timestamp = datetime.now() - timedelta(days=max_age_days+1)

            # Calculate age score (newer is better)
            age_factor = max(0.0, min(1.0, (timestamp - cutoff_date).total_seconds() /
                                      (datetime.now() - cutoff_date).total_seconds()))

            # Calculate importance score based on document type and access frequency
            importance_factor = 1.0

            # Tool knowledge is more valuable
            if hasattr(doc, "metadata") and doc.metadata and doc.metadata.get("type") == "tool_knowledge":
                importance_factor += 0.5

            # If document has been accessed often, increase importance
            if hasattr(doc, "metadata") and doc.metadata and "access_count" in doc.metadata:
                importance_factor += min(1.0, doc.metadata["access_count"] / 10)

            # If document contains references to complex tools, prioritize it
            if hasattr(doc, "page_content"):
                complex_tools = ["web_search", "python_repl", "analyze_image", "arxiv_search"]
                if any(tool in doc.page_content for tool in complex_tools):
                    importance_factor += 0.3

            # Create combined score (higher = more valuable to keep)
            total_score = (0.6 * age_factor) + (0.4 * importance_factor)

            # Add to priority queue (negative for max-heap behavior)
            heapq.heappush(scored_docs, (-total_score, i, doc))

        # Select top documents to keep
        docs_to_keep = []
        for _ in range(min(max_documents, len(scored_docs))):
            if scored_docs:
                _, _, doc = heapq.heappop(scored_docs)
                docs_to_keep.append(doc)

        # Only rebuild if we're actually pruning some documents
        if len(docs_to_keep) < len(all_docs):
            print(f"Memory management: Keeping {len(docs_to_keep)} documents out of {len(all_docs)}")

            # Create a new vector store with the same type as the current one
            vector_store_type = type(self.vector_store)

            # Different approaches based on vector store type
            if hasattr(vector_store_type, "from_documents"):
                # Most langchain vector stores support this method
                new_vector_store = vector_store_type.from_documents(
                    docs_to_keep,
                    embedding=self.embedding_model
                )
                self.vector_store = new_vector_store
                print(f"Vector store rebuilt with {len(docs_to_keep)} documents")

            elif hasattr(vector_store_type, "from_texts"):
                # For vector stores that use from_texts
                texts = [doc.page_content for doc in docs_to_keep]
                metadatas = [doc.metadata if hasattr(doc, "metadata") else {} for doc in docs_to_keep]

                new_vector_store = vector_store_type.from_texts(
                    texts=texts,
                    embedding=self.embedding_model,
                    metadatas=metadatas
                )
                self.vector_store = new_vector_store
                print(f"Vector store rebuilt with {len(docs_to_keep)} documents")

            else:
                print("Warning: Could not determine how to rebuild the vector store")
                print(f"Vector store type: {vector_store_type.__name__}")

# Example usage
if __name__ == "__main__":
    from langchain_huggingface import HuggingFaceEmbeddings
    from langchain_chroma import Chroma
    from langchain_groq import ChatGroq
    from basic_tools import multiply, add, subtract, divide, wiki_search, web_search

    # Initialize embeddings
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
        model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"}
    )

    # Initialize vector store
    vector_store = Chroma(
        embedding_function=embeddings,
        collection_name="advanced_agent_memory"
    )

    # Initialize LLM
    llm = ChatGroq(model="qwen-qwq-32b", temperature=0)

    # Define tools
    tools = [multiply, add, subtract, divide, wiki_search, web_search]

    # Create agent
    agent = AdvancedToolAgent(
        embedding_model=embeddings,
        vector_store=vector_store,
        llm=llm,
        tools=tools
    )

    # Test the agent
    response = agent("What is the population of France multiplied by 2?")
    print(f"Response: {response}")
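The pruning logic in manage_memory above folds recency and importance into a single retention score (0.6 * age_factor + 0.4 * importance_factor). A minimal standalone sketch of that weighting, with made-up documents, for quick sanity-checking:

import heapq
from datetime import datetime, timedelta

def memory_score(timestamp, importance, max_age_days=30):
    """Combine a 0-1 recency factor with an importance factor, as in manage_memory."""
    cutoff = datetime.now() - timedelta(days=max_age_days)
    span = (datetime.now() - cutoff).total_seconds()
    age_factor = max(0.0, min(1.0, (timestamp - cutoff).total_seconds() / span))
    return 0.6 * age_factor + 0.4 * importance

# Three hypothetical memories: keep the top two by score
docs = [("old qa", datetime.now() - timedelta(days=29), 1.0),
        ("tool note", datetime.now() - timedelta(days=2), 1.5),
        ("fresh qa", datetime.now(), 1.0)]
top = heapq.nlargest(2, docs, key=lambda d: memory_score(d[1], d[2]))
print([name for name, _, _ in top])  # ['tool note', 'fresh qa']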
agent.py
CHANGED
@@ -7,7 +7,8 @@ from langchain.vectorstores import VectorStore
 from langchain_core.language_models import BaseChatModel
 from langgraph.prebuilt import tools_condition
 from langgraph.prebuilt import ToolNode
-from langchain_community.vectorstores import Chroma
+# from langchain_community.vectorstores import Chroma
+
 from langchain_core.documents import Document
 from langchain_groq import ChatGroq
 from basic_tools import *
@@ -17,24 +18,16 @@ from datetime import datetime, timedelta
 from sentence_transformers import SentenceTransformer
 import torch
 import heapq
+from utils import *
 
 os.environ['HF_HOME'] = os.path.join(
     os.path.expanduser('~'), '.cache', "huggingface")
 
-embeddings = HuggingFaceEmbeddings(
-    model_name="sentence-transformers/all-mpnet-base-v2",
-    # hugging_face_api_key=os.getenv("HF_TOKEN"),
-    model_kwargs={"device": "gpu" if torch.cuda.is_available() else "cpu",
-                  "token": os.getenv("HF_TOKEN")},
-    show_progress=True,
-)
-vector_store: FAISS = FAISS.from_texts(
-    texts=[],
-    embedding=embeddings)
+
 
 
 # load the system prompt from the file
-with open("system_prompt.txt", "r", encoding="utf-8") as f:
+with open("./system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()
@@ -42,19 +35,13 @@
 sys_msg = SystemMessage(content=system_prompt)
 
 
-
-
 class BasicAgent:
     tools: List[BaseTool] = [multiply,
-                             add,
-                             subtract,
-                             divide,
-                             modulus,
-                             wiki_search,
-                             web_search,
-                             arxiv_search,
-                             requests_get,
-                             requests_post
+                             multiply, add, subtract, divide, modulus,
+                             wiki_search, web_search, arxiv_search,
+                             python_repl, analyze_image,
+                             date_filter, analyze_content,
+                             step_by_step_reasoning, translate_text
                              ]
     def __init__(self, embeddings: HuggingFaceEmbeddings, vector_store: VectorStore, llm: BaseChatModel):
         self.embedding_model = embeddings
@@ -72,12 +59,12 @@ class BasicAgent:
     def __call__(self, question: str) -> str:
         print(f"Agent received question (first 50 chars): {question[:50]}...")
 
-        # Search for similar content to enhance context
-        similar_docs = self.vector_store.similarity_search(question, k=3)
+        # Search for similar content to enhance context - LIMIT TO 1 DOCUMENT ONLY
+        similar_docs = self.vector_store.similarity_search(question, k=1)  # Reduced from 3 to 1
 
         # Create enhanced context with relevant past information
         enhanced_context = question
-        if similar_docs:
+        if (similar_docs):
             context_additions = []
             for doc in similar_docs:
                 # Extract relevant information from similar documents
@@ -85,15 +72,22 @@ class BasicAgent:
                 if "Question:" in content and "Final answer:" in content:
                     q = content.split("Question:")[1].split("Final answer:")[0].strip()
                     a = content.split("Final answer:")[1].split("Timestamp:", 1)[0].strip()
+
+                    # Truncate long contexts
+                    if len(q) > 200:
+                        q = q[:200] + "..."
+                    if len(a) > 300:
+                        a = a[:300] + "..."
+
                     # Only add if it's not exactly the same question
                     if not question.lower() == q.lower():
                         context_additions.append(f"Related Q: {q}\nRelated A: {a}")
 
             if context_additions:
                 enhanced_context = (
-                    "Consider this relevant information first:\n\n" +
-                    "\n\n".join(context_additions) +
-                    "\n\nNow answering this question: " + question
+                    "Consider this relevant information first:\n\n" +
+                    "\n\n".join(context_additions[:1]) +  # Only use the first context addition
+                    "\n\nNow answering this question: " + question
                 )
 
         # Process with the graph
@@ -189,7 +183,7 @@ class BasicAgent:
             tools_condition,
             {
                 "tools": "tools",
-                None: END,
+                END: END  # Using END as the key instead of None
             }
         )
         builder.add_edge("tools", "context_enhanced_generation")
@@ -218,7 +212,7 @@ class BasicAgent:
             base_url="http://localhost:11432/v1",  # default LM Studio endpoint
             api_key="not-used",  # required by interface but ignored #type: ignore
             # model="mistral-nemo-instruct-2407",
-            model="llama-3.1-8b-instruct",
+            model="meta-llama-3.1-8b-instruct",
             temperature=0.2
         )
     elif provider == "openai":
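The updated __call__ injects truncated retrieved context ahead of the question. A minimal sketch of the same 200/300-character caps and prompt format, in isolation (the sample strings are made up):

def build_context(question, related_q, related_a):
    # Truncate long retrieved context, mirroring the caps in the diff above
    if len(related_q) > 200:
        related_q = related_q[:200] + "..."
    if len(related_a) > 300:
        related_a = related_a[:300] + "..."
    return ("Consider this relevant information first:\n\n"
            f"Related Q: {related_q}\nRelated A: {related_a}"
            "\n\nNow answering this question: " + question)

print(build_context("What is 2 + 2?", "What is 1 + 1?", "2"))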
app.py
CHANGED
@@ -3,7 +3,9 @@ import gradio as gr
 import requests
 import inspect
 import pandas as pd
-from agent import BasicAgent, embeddings, vector_store
+# from agent import BasicAgent, embeddings, vector_store
+from utils import embeddings, vector_store
+from reasoning_agent import ReasoningAgent
 from dotenv import load_dotenv
 import os
 
@@ -39,8 +41,9 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
     # 1. Instantiate Agent ( modify this part to create your agent)
     try:
         # llm = BasicAgent.get_llm("groq")
-        llm = BasicAgent.get_llm("openai_local")
-        agent = BasicAgent(embeddings, vector_store, llm)
+        # llm = BasicAgent.get_llm("openai_local")
+        # agent = BasicAgent(embeddings, vector_store, llm)
+        agent = ReasoningAgent()
         print("Agent instantiated successfully.")
     except Exception as e:
         print(f"Error instantiating agent: {e}")
@@ -74,19 +77,33 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
     results_log = []
     answers_payload = []
     print(f"Running agent on {len(questions_data)} questions...")
-    for item in questions_data:
+
+    import time
+
+    # Pace the questions with short delays to avoid rate limits
+    for i, item in enumerate(questions_data):
         task_id = item.get("task_id")
         question_text = item.get("question")
         if not task_id or question_text is None:
             print(f"Skipping item with missing task_id or question: {item}")
             continue
+
         try:
+            print(f"Processing question {i+1}/{len(questions_data)}: {task_id}")
             submitted_answer = agent(question_text)
             answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
             results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
+
+            # Add delay between questions to avoid rate limiting (5 seconds)
+            if i < len(questions_data) - 1:
+                print(f"Waiting 5 seconds before next question...")
+                time.sleep(5)
+
         except Exception as e:
             print(f"Error running agent on task {task_id}: {e}")
             results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
+            # Continue with the next question after a short delay even if there was an error
+            time.sleep(3)
 
     if not answers_payload:
         print("Agent did not produce any answers to submit.")
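The new submission loop paces requests rather than batching them. A minimal sketch of the same pattern (the delay values mirror the diff; the worker callable and everything else is illustrative):

import time

def run_paced(items, worker, delay_ok=5.0, delay_err=3.0):
    """Run worker over items, sleeping between calls to stay under rate limits."""
    results = []
    for i, item in enumerate(items):
        try:
            results.append(worker(item))
            if i < len(items) - 1:
                time.sleep(delay_ok)
        except Exception as e:
            results.append(f"AGENT ERROR: {e}")
            time.sleep(delay_err)
    return results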
basic_tools.py
CHANGED
@@ -1,3 +1,7 @@
+from youtube_transcript_api.formatters import TextFormatter
+from youtube_transcript_api import YouTubeTranscriptApi
+import requests
+from typing import Dict, List, Optional, Any, Union
 import os
 from dotenv import load_dotenv
 from langgraph.graph import START, StateGraph, MessagesState
@@ -80,17 +84,27 @@ def modulus(a: int, b: int) -> int:
 
 @tool
 def wiki_search(query: str) -> str:
-    """Search Wikipedia for a query and return maximum 2 results.
+    """Search Wikipedia for a query and return maximum 5 results.
 
     Args:
-        query: The search query.
-    """
-    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
-    formatted_search_docs = "\n\n---\n\n".join(
-        [
-            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-            for doc in search_docs
-        ])
-    return formatted_search_docs
+        query: The search query. Be specific with search terms including full names, dates, and relevant keywords.
+    """
+    if not query or query.strip() == "":
+        return "Error: Please provide a valid search query with specific terms."
+
+    try:
+        search_docs = WikipediaLoader(query=query, load_max_docs=5).load()
+        if not search_docs:
+            return f"No Wikipedia results found for '{query}'. Consider refining your search terms."
+
+        formatted_search_docs = "\n\n---\n\n".join(
+            [
+                f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+                for doc in search_docs
+            ])
+        return formatted_search_docs
+    except Exception as e:
+        return f"Error searching Wikipedia: {str(e)}. Please try a different query."
 
 
 @tool
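The hardened wiki_search above wraps each hit in a <Document> envelope before returning. A minimal sketch of that formatting with a made-up result, for reference:

sample = [{"source": "https://en.wikipedia.org/wiki/France", "page": "", "text": "France is a country in Western Europe..."}]
formatted = "\n\n---\n\n".join(
    f'<Document source="{d["source"]}" page="{d["page"]}"/>\n{d["text"]}\n</Document>'
    for d in sample
)
print(formatted)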
1 |
+
from youtube_transcript_api.formatters import TextFormatter
|
2 |
+
from youtube_transcript_api import YouTubeTranscriptApi
|
3 |
+
import requests
|
4 |
+
from typing import Dict, List, Optional, Any, Union
|
5 |
import os
|
6 |
from dotenv import load_dotenv
|
7 |
from langgraph.graph import START, StateGraph, MessagesState
|
|
|
84 |
|
85 |
@tool
|
86 |
def wiki_search(query: str) -> str:
|
87 |
+
"""Search Wikipedia for a query and return maximum 5 results.
|
88 |
|
89 |
Args:
|
90 |
+
query: The search query. Be specific with search terms including full names, dates, and relevant keywords.
|
91 |
+
"""
|
92 |
+
if not query or query.strip() == "":
|
93 |
+
return "Error: Please provide a valid search query with specific terms."
|
94 |
+
|
95 |
+
try:
|
96 |
+
search_docs = WikipediaLoader(query=query, load_max_docs=5).load()
|
97 |
+
if not search_docs:
|
98 |
+
return f"No Wikipedia results found for '{query}'. Consider refining your search terms."
|
99 |
+
|
100 |
+
formatted_search_docs = "\n\n---\n\n".join(
|
101 |
+
[
|
102 |
+
f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
|
103 |
+
for doc in search_docs
|
104 |
+
])
|
105 |
+
return formatted_search_docs
|
106 |
+
except Exception as e:
|
107 |
+
return f"Error searching Wikipedia: {str(e)}. Please try a different query."
|
108 |
|
109 |
|
110 |
@tool
|
|
|
174 |
# response = toolkit.run(url, data=data, json=json, headers=headers)
|
175 |
# return response.text
|
176 |
|
177 |
+
@tool
|
178 |
+
def date_filter(content: str, start_year: int, end_year: int) -> str:
|
179 |
+
"""Filter content based on date range and extract relevant information.
|
180 |
+
|
181 |
+
Args:
|
182 |
+
content: The text content to filter
|
183 |
+
start_year: Starting year (inclusive)
|
184 |
+
end_year: Ending year (inclusive)
|
185 |
+
"""
|
186 |
+
if not content or not isinstance(content, str):
|
187 |
+
return "Error: No content provided for filtering."
|
188 |
+
|
189 |
+
try:
|
190 |
+
# Convert years to strings for matching
|
191 |
+
years = [str(year) for year in range(start_year, end_year + 1)]
|
192 |
+
|
193 |
+
# Split content into paragraphs
|
194 |
+
paragraphs = content.split("\n")
|
195 |
+
|
196 |
+
# Filter paragraphs containing any year in the range
|
197 |
+
filtered_paragraphs = []
|
198 |
+
for paragraph in paragraphs:
|
199 |
+
if any(f" {year}" in paragraph or f"({year})" in paragraph or f"[{year}]" in paragraph for year in years):
|
200 |
+
filtered_paragraphs.append(paragraph)
|
201 |
+
|
202 |
+
if not filtered_paragraphs:
|
203 |
+
return f"No content found specifically mentioning years between {start_year} and {end_year}."
|
204 |
+
|
205 |
+
return "\n\n".join(filtered_paragraphs)
|
206 |
+
except Exception as e:
|
207 |
+
return f"Error filtering by date range: {str(e)}"
|
208 |
+
|
209 |
+
import re
|
210 |
+
|
211 |
+
@tool
|
212 |
+
def count_items(content: str, pattern: str, context_words: int = 5) -> str:
|
213 |
+
"""Count items matching a pattern in content and extract contextual information.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
content: The text to analyze
|
217 |
+
pattern: The pattern to search for (e.g. "album", "publication")
|
218 |
+
context_words: Number of words to include for context around matches
|
219 |
+
"""
|
220 |
+
if not content or not pattern:
|
221 |
+
return "Error: Both content and pattern must be provided."
|
222 |
+
|
223 |
+
try:
|
224 |
+
# Find all occurrences of the pattern
|
225 |
+
matches = re.finditer(r'(?i)\b\w*' + re.escape(pattern) + r'\w*\b', content)
|
226 |
+
|
227 |
+
# Extract context around matches
|
228 |
+
contexts = []
|
229 |
+
count = 0
|
230 |
+
|
231 |
+
for match in matches:
|
232 |
+
count += 1
|
233 |
+
start, end = match.span()
|
234 |
+
|
235 |
+
# Get text before and after the match
|
236 |
+
text_before = content[max(0, start-100):start]
|
237 |
+
text_after = content[end:min(len(content), end+100)]
|
238 |
+
|
239 |
+
# Create contextual excerpt
|
240 |
+
context = f"...{text_before}{match.group(0)}{text_after}..."
|
241 |
+
contexts.append(context)
|
242 |
+
|
243 |
+
if count == 0:
|
244 |
+
return f"No items matching '{pattern}' found in the content."
|
245 |
+
|
246 |
+
result = f"Found {count} occurrences of '{pattern}'. Contexts:\n\n"
|
247 |
+
result += "\n---\n".join(contexts[:10]) # Limit to first 10 for brevity
|
248 |
+
|
249 |
+
return result
|
250 |
+
except Exception as e:
|
251 |
+
return f"Error counting items: {str(e)}"
|
252 |
+
|
253 |
+
@tool
|
254 |
+
def translate_text(text: str, target_language: str) -> str:
|
255 |
+
"""Translate text to the specified language using a simple translation API.
|
256 |
+
|
257 |
+
Args:
|
258 |
+
text: Text to translate
|
259 |
+
target_language: Target language (e.g., "Spanish", "French", "German")
|
260 |
+
"""
|
261 |
+
if not text:
|
262 |
+
return "Error: No text provided for translation."
|
263 |
+
|
264 |
+
try:
|
265 |
+
# Using LibreTranslate API (open-source translation)
|
266 |
+
API_URL = "https://translate.argosopentech.com/translate"
|
267 |
+
|
268 |
+
# Map common language names to language codes
|
269 |
+
language_map = {
|
270 |
+
"english": "en",
|
271 |
+
"spanish": "es",
|
272 |
+
"french": "fr",
|
273 |
+
"german": "de",
|
274 |
+
"italian": "it",
|
275 |
+
"portuguese": "pt",
|
276 |
+
"russian": "ru",
|
277 |
+
"japanese": "ja",
|
278 |
+
"chinese": "zh",
|
279 |
+
"arabic": "ar",
|
280 |
+
"hindi": "hi",
|
281 |
+
"korean": "ko"
|
282 |
+
}
|
283 |
+
|
284 |
+
# Get language code
|
285 |
+
target_code = language_map.get(target_language.lower())
|
286 |
+
if not target_code:
|
287 |
+
return f"Error: Unsupported language '{target_language}'. Supported languages: {', '.join(language_map.keys())}."
|
288 |
+
|
289 |
+
# Prepare request
|
290 |
+
payload = {
|
291 |
+
"q": text[:500], # Limit text length to avoid API issues
|
292 |
+
"source": "auto",
|
293 |
+
"target": target_code
|
294 |
+
}
|
295 |
+
|
296 |
+
response = requests.post(API_URL, json=payload)
|
297 |
+
if response.status_code == 200:
|
298 |
+
translation = response.json().get("translatedText", "")
|
299 |
+
return f"Original: {text[:100]}{'...' if len(text) > 100 else ''}\n\nTranslation ({target_language}): {translation}"
|
300 |
+
else:
|
301 |
+
return f"Translation API error: {response.status_code} - {response.text}"
|
302 |
+
except Exception as e:
|
303 |
+
return f"Error translating text: {str(e)}"
|
304 |
+
|
305 |
+
@tool
|
306 |
+
def step_by_step_reasoning(problem: str, steps: int = 3) -> str:
|
307 |
+
"""Break down a complex problem into steps for clearer reasoning.
|
308 |
+
|
309 |
+
Args:
|
310 |
+
problem: The problem statement or question to analyze
|
311 |
+
steps: Number of reasoning steps (default: 3)
|
312 |
+
"""
|
313 |
+
if not problem:
|
314 |
+
return "Error: No problem provided for analysis."
|
315 |
+
|
316 |
+
try:
|
317 |
+
# Structure for breaking down any problem
|
318 |
+
result = f"Breaking down: {problem}\n\n"
|
319 |
+
|
320 |
+
# Generic reasoning steps that work for many problems
|
321 |
+
reasoning_steps = [
|
322 |
+
"Identify the key information and requirements in the problem",
|
323 |
+
"Determine what knowledge or method is needed to solve it",
|
324 |
+
"Apply relevant formulas, data, or logical steps",
|
325 |
+
"Verify the solution against the original requirements",
|
326 |
+
"Consider alternative approaches or edge cases"
|
327 |
+
]
|
328 |
+
|
329 |
+
# Use only the requested number of steps
|
330 |
+
steps_to_use = min(steps, len(reasoning_steps))
|
331 |
+
for i in range(steps_to_use):
|
332 |
+
result += f"Step {i+1}: {reasoning_steps[i]}\n"
|
333 |
+
result += f"This step involves analyzing {problem} by "
|
334 |
+
|
335 |
+
if i == 0:
|
336 |
+
# First step focuses on understanding the problem
|
337 |
+
keywords = re.findall(r'\b\w{5,}\b', problem)
|
338 |
+
key_concepts = [word for word in keywords if len(word) > 4][:3]
|
339 |
+
if key_concepts:
|
340 |
+
result += f"identifying key concepts like {', '.join(key_concepts)}. "
|
341 |
+
|
342 |
+
# Identify question type
|
343 |
+
if "how many" in problem.lower():
|
344 |
+
result += "This is a counting or quantification problem. "
|
345 |
+
elif "when" in problem.lower():
|
346 |
+
result += "This is a timing or chronological problem. "
|
347 |
+
elif "where" in problem.lower():
|
348 |
+
result += "This is a location or spatial problem. "
|
349 |
+
elif "who" in problem.lower():
|
350 |
+
result += "This is a person or entity identification problem. "
|
351 |
+
elif "why" in problem.lower():
|
352 |
+
result += "This is a causation or reasoning problem. "
|
353 |
+
|
354 |
+
result += "We need to extract specific details from the problem statement.\n\n"
|
355 |
+
|
356 |
+
elif i == 1:
|
357 |
+
# Second step focuses on approach
|
358 |
+
if "between" in problem.lower() and re.search(r'\d{4}', problem):
|
359 |
+
result += "using date filtering to focus on the specific time period. "
|
360 |
+
result += "We need to identify relevant dates and associated events/items.\n\n"
|
361 |
+
elif any(word in problem.lower() for word in ["album", "song", "music", "artist", "band"]):
|
362 |
+
result += "examining discography information and music-related details. "
|
363 |
+
result += "We should focus on releases, titles, and years.\n\n"
|
364 |
+
elif any(word in problem.lower() for word in ["calculate", "compute", "sum", "average", "total"]):
|
365 |
+
result += "applying mathematical operations to derive a numeric result. "
|
366 |
+
result += "We need to identify the values and operations required.\n\n"
|
367 |
+
else:
|
368 |
+
result += "gathering relevant factual information and organizing it logically. "
|
369 |
+
result += "We should separate facts from assumptions.\n\n"
|
370 |
+
|
371 |
+
elif i == 2:
|
372 |
+
# Third step focuses on solution path
|
373 |
+
result += "determining the specific steps to reach a solution. "
|
374 |
+
result += "This may involve counting items, applying formulas, or comparing data.\n\n"
|
375 |
+
|
376 |
+
elif i == 3:
|
377 |
+
# Fourth step focuses on verification
|
378 |
+
result += "checking our answer against the original question requirements. "
|
379 |
+
result += "We should verify that we've fully addressed all parts of the question.\n\n"
|
380 |
+
|
381 |
+
else:
|
382 |
+
# Fifth step focuses on alternatives
|
383 |
+
result += "considering other approaches or edge cases we might have missed. "
|
384 |
+
result += "This ensures our answer is robust and comprehensive.\n\n"
|
385 |
+
|
386 |
+
result += "\nThis structured approach helps organize thinking and ensures a thorough analysis."
|
387 |
+
return result
|
388 |
+
|
389 |
+
except Exception as e:
|
390 |
+
return f"Error performing step-by-step reasoning: {str(e)}"
|
391 |
+
|
392 |
+
@tool
|
393 |
+
def analyze_content(content: str, analysis_type: str) -> str:
|
394 |
+
"""Analyze content for specific information based on analysis type.
|
395 |
+
|
396 |
+
Args:
|
397 |
+
content: Text content to analyze
|
398 |
+
analysis_type: Type of analysis to perform ('dates', 'names', 'numbers', 'events')
|
399 |
+
"""
|
400 |
+
if not content:
|
401 |
+
return "Error: No content provided for analysis."
|
402 |
+
|
403 |
+
analysis_type = analysis_type.lower()
|
404 |
+
|
405 |
+
try:
|
406 |
+
if analysis_type == 'dates':
|
407 |
+
# Extract dates in various formats
|
408 |
+
date_patterns = [
|
409 |
+
r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b', # DD/MM/YYYY or MM/DD/YYYY
|
410 |
+
r'\b\d{1,2}\s(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s\d{2,4}\b', # DD Month YYYY
|
411 |
+
r'\b(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s\d{1,2}(?:st|nd|rd|th)?,\s\d{2,4}\b', # Month DD, YYYY
|
412 |
+
r'\b\d{4}\b' # YYYY (years)
|
413 |
+
]
|
414 |
+
results = []
|
415 |
+
for pattern in date_patterns:
|
416 |
+
matches = re.findall(pattern, content, re.IGNORECASE)
|
417 |
+
results.extend(matches)
|
418 |
+
|
419 |
+
return f"Found {len(results)} date references:\n\n" + "\n".join(results)
|
420 |
+
|
421 |
+
elif analysis_type == 'names':
|
422 |
+
# Basic name extraction (this is simplified, real NER would be better)
|
423 |
+
name_pattern = r'\b[A-Z][a-z]+\s[A-Z][a-z]+\b'
|
424 |
+
names = re.findall(name_pattern, content)
|
425 |
+
return f"Found {len(names)} potential names:\n\n" + "\n".join(names)
|
426 |
+
|
427 |
+
elif analysis_type == 'numbers':
|
428 |
+
# Extract numbers and their context
|
429 |
+
number_pattern = r'\b\d+(?:,\d+)*(?:\.\d+)?\b'
|
430 |
+
numbers = re.findall(number_pattern, content)
|
431 |
+
|
432 |
+
# Get context for each number
|
433 |
+
contexts = []
|
434 |
+
for number in numbers:
|
435 |
+
index = content.find(number)
|
436 |
+
start = max(0, index - 50)
|
437 |
+
end = min(len(content), index + len(number) + 50)
|
438 |
+
context = content[start:end].replace('\n', ' ').strip()
|
439 |
+
contexts.append(f"{number}: \"{context}\"")
|
440 |
+
|
441 |
+
return f"Found {len(numbers)} numbers with context:\n\n" + "\n".join(contexts[:20]) # Limit to 20
|
442 |
+
|
443 |
+
elif analysis_type == 'events':
|
444 |
+
# Look for event indicators
|
445 |
+
event_patterns = [
|
446 |
+
r'\b(?:occurred|happened|took place|event|ceremony|concert|release|published|awarded|presented)\b',
|
447 |
+
r'\b(?:in|on|during|at)\s\d{4}\b'
|
448 |
+
]
|
449 |
+
events = []
|
450 |
+
for pattern in event_patterns:
|
451 |
+
for match in re.finditer(pattern, content, re.IGNORECASE):
|
452 |
+
start = max(0, match.start() - 100)
|
453 |
+
end = min(len(content), match.end() + 100)
|
454 |
+
context = content[start:end].replace('\n', ' ').strip()
|
455 |
+
events.append(context)
|
456 |
+
|
457 |
+
return f"Found {len(events)} potential events:\n\n" + "\n\n".join(events[:15]) # Limit to 15
|
458 |
+
|
459 |
+
else:
|
460 |
+
return f"Error: Unsupported analysis type '{analysis_type}'. Use 'dates', 'names', 'numbers', or 'events'."
|
461 |
+
|
462 |
+
except Exception as e:
|
463 |
+
return f"Error during content analysis: {str(e)}"
|
464 |
+
|
465 |
+
|
466 |
+
@tool
|
467 |
+
def youtube_transcript(url: str, summarize: bool = True) -> str:
|
468 |
+
"""Extract transcript from YouTube video and optionally summarize it.
|
469 |
+
|
470 |
+
Args:
|
471 |
+
url: YouTube video URL or video ID
|
472 |
+
summarize: Whether to summarize the transcript (default: True)
|
473 |
+
"""
|
474 |
+
try:
|
475 |
+
# Extract video ID from URL
|
476 |
+
video_id_match = re.search(r'(?:v=|\/)([0-9A-Za-z_-]{11}).*', url)
|
477 |
+
if video_id_match:
|
478 |
+
video_id = video_id_match.group(1)
|
479 |
+
else:
|
480 |
+
# Try using the input directly as a video ID
|
481 |
+
if len(url) == 11:
|
482 |
+
video_id = url
|
483 |
+
else:
|
484 |
+
return "Error: Invalid YouTube URL or video ID. Please provide a valid YouTube URL."
|
485 |
+
|
486 |
+
# Get transcript
|
487 |
+
transcript = YouTubeTranscriptApi.get_transcript(video_id)
|
488 |
+
formatter = TextFormatter()
|
489 |
+
formatted_transcript = formatter.format_transcript(transcript)
|
490 |
+
|
491 |
+
# Get video metadata
|
492 |
+
response = requests.get(
|
493 |
+
f"https://www.youtube.com/oembed?url=http://www.youtube.com/watch?v={video_id}&format=json")
|
494 |
+
metadata = response.json()
|
495 |
+
title = metadata.get("title", "Unknown title")
|
496 |
+
author = metadata.get("author_name", "Unknown author")
|
497 |
+
|
498 |
+
if summarize and formatted_transcript:
|
499 |
+
# For long transcripts, break into chunks
|
500 |
+
max_chunk_length = 4000
|
501 |
+
if len(formatted_transcript) > max_chunk_length:
|
502 |
+
chunks = [formatted_transcript[i:i+max_chunk_length]
|
503 |
+
for i in range(0, len(formatted_transcript), max_chunk_length)]
|
504 |
+
summary = f"Video: \"{title}\" by {author}\n\nTranscript summary (extracted from {len(chunks)} segments):\n\n"
|
505 |
+
|
506 |
+
# Return first and last parts of transcript instead of full summary for long videos
|
507 |
+
summary += f"Beginning of transcript:\n{chunks[0][:500]}...\n\n"
|
508 |
+
summary += f"End of transcript:\n{chunks[-1][-500:]}"
|
509 |
+
return summary
|
510 |
+
else:
|
511 |
+
return f"Video: \"{title}\" by {author}\n\nFull transcript:\n\n{formatted_transcript}"
|
512 |
+
else:
|
513 |
+
return f"Video: \"{title}\" by {author}\n\nFull transcript:\n\n{formatted_transcript}"
|
514 |
+
|
515 |
+
except Exception as e:
|
516 |
+
return f"Error extracting YouTube transcript: {str(e)}"
|
517 |
+
|
518 |
+
import base64
from io import BytesIO
from PIL import Image
import json


@tool
def analyze_image(image_url: str, analysis_type: str = "caption") -> str:
    """Analyze an image from a URL and provide captions, tags, or comprehensive analysis.

    Args:
        image_url: URL of the image to analyze
        analysis_type: Type of analysis to perform (options: "caption", "tags", "objects", "comprehensive")
    """
    if not image_url:
        return "Error: Please provide a valid image URL."

    analysis_type = analysis_type.lower()
    valid_types = ["caption", "tags", "objects", "comprehensive"]

    if analysis_type not in valid_types:
        return f"Error: analysis_type must be one of {', '.join(valid_types)}."

    try:
        # Download the image
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()

        # Process image based on analysis type
        if analysis_type == "caption":
            return caption_image(response.content)
        elif analysis_type == "tags":
            return tag_image(response.content)
        elif analysis_type == "objects":
            return detect_objects(response.content)
        elif analysis_type == "comprehensive":
            # Perform all analyses
            caption_result = caption_image(response.content)
            tags_result = tag_image(response.content)
            objects_result = detect_objects(response.content)

            return f"IMAGE ANALYSIS SUMMARY:\n\n{caption_result}\n\n{tags_result}\n\n{objects_result}"
        # If none of the above conditions are met, return an error string
        return "Error: Unknown analysis type or failed to process image."
    except requests.exceptions.RequestException as e:
        return f"Error downloading image: {str(e)}"
    except Exception as e:
        return f"Error analyzing image: {str(e)}"

def caption_image(image_content: bytes) -> str:
    """Generate captions for an image using the Hugging Face API."""
    try:
        # Check if we have an HF API key in the environment
        hf_api_key = os.getenv("HUGGINGFACE_API_TOKEN")

        if hf_api_key:
            # Use the Hugging Face API with auth
            api_url = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
            headers = {"Authorization": f"Bearer {hf_api_key}"}

            # Convert image to base64
            image_b64 = base64.b64encode(image_content).decode("utf-8")
            payload = {"inputs": {"image": image_b64}}

            response = requests.post(api_url, headers=headers, json=payload)
            if response.status_code == 200:
                result = response.json()
                if isinstance(result, list) and len(result) > 0:
                    return f"CAPTION: {result[0]['generated_text']}"
                else:
                    return f"CAPTION: {result['generated_text'] if 'generated_text' in result else str(result)}"
            else:
                # Fall back to the public API
                return caption_image_public(image_content)
        else:
            # No API key, use the public endpoint
            return caption_image_public(image_content)

    except Exception as e:
        return f"Error generating caption: {str(e)}"

def caption_image_public(image_content: bytes) -> str:
    """Caption image using a public API endpoint."""
    try:
        # Convert to PIL image for processing
        image = Image.open(BytesIO(image_content))

        # Resize if too large (to avoid timeouts)
        max_size = 1024
        if max(image.size) > max_size:
            ratio = max_size / max(image.size)
            new_size = (int(image.size[0] * ratio), int(image.size[1] * ratio))
            image = image.resize(new_size, Image.LANCZOS)

        # Convert back to bytes
        buffer = BytesIO()
        image.save(buffer, format="JPEG")
        image_bytes = buffer.getvalue()

        # Call public API
        api_url = "https://api.toonify.photos/caption"  # Example public API
        files = {"image": ("image.jpg", image_bytes, "image/jpeg")}

        response = requests.post(api_url, files=files, timeout=15)
        if response.status_code == 200:
            result = response.json()
            return f"CAPTION: {result.get('caption', 'No caption generated')}"
        else:
            return "CAPTION: Could not generate caption (API error)"
    except Exception as e:
        return f"CAPTION: Image appears to be a {detect_simple_content(image_content)}"

def tag_image(image_content: bytes) -> str:
    """Generate tags for an image."""
    try:
        # Check if we have an HF API key in the environment
        hf_api_key = os.getenv("HUGGINGFACE_API_TOKEN")

        if hf_api_key:
            # Use the Hugging Face API for image tagging
            api_url = "https://api-inference.huggingface.co/models/google/vit-base-patch16-224"
            headers = {"Authorization": f"Bearer {hf_api_key}"}

            # Send image as binary content
            response = requests.post(api_url, headers=headers, data=image_content)
            if response.status_code == 200:
                tags = response.json()
                # Format results
                formatted_tags = "\n".join([f"- {tag['label']} ({tag['score']:.2%})" for tag in tags[:10]])
                return f"TAGS:\n{formatted_tags}"
            else:
                # Fall back to basic detection
                return f"TAGS:\n- {detect_simple_content(image_content)}"
        else:
            # No API key
            return f"TAGS:\n- {detect_simple_content(image_content)}"
    except Exception as e:
        return f"Error generating tags: {str(e)}"

def detect_objects(image_content: bytes) -> str:
    """Detect objects in an image."""
    try:
        # Check if we have an HF API key in the environment
        hf_api_key = os.getenv("HUGGINGFACE_API_TOKEN")

        if hf_api_key:
            # Use the Hugging Face API for object detection
            api_url = "https://api-inference.huggingface.co/models/facebook/detr-resnet-50"
            headers = {"Authorization": f"Bearer {hf_api_key}"}

            # Send image as binary content
            response = requests.post(api_url, headers=headers, data=image_content)
            if response.status_code == 200:
                objects = response.json()

                # Count objects by label
                object_counts = {}
                for obj in objects:
                    label = obj["label"]
                    if label in object_counts:
                        object_counts[label] += 1
                    else:
                        object_counts[label] = 1

                # Format results
                formatted_objects = "\n".join([f"- {count}× {label}" for label, count in object_counts.items()])
                return f"OBJECTS DETECTED:\n{formatted_objects}"
            else:
                return "OBJECTS: Could not detect objects (API error)"
        else:
            return "OBJECTS: API key required for object detection"
    except Exception as e:
        return f"Error detecting objects: {str(e)}"

def detect_simple_content(image_content: bytes) -> str:
    """Simple function to detect basic image type when APIs are not available."""
    try:
        image = Image.open(BytesIO(image_content))
        width, height = image.size
        aspect = width / height

        # Very simple heuristics
        if aspect > 2:
            return "panorama or banner image"
        elif aspect < 0.5:
            return "tall or portrait image"
        elif width < 300 or height < 300:
            return "small thumbnail or icon"
        else:
            return "photograph or general image"
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt is not swallowed
        return "image (could not analyze format)"

import contextlib
from io import StringIO


@tool
def python_repl(code: str) -> str:
    """Execute Python code and return the result.

    Args:
        code: Python code to execute
    """
    if not code or not isinstance(code, str):
        return "Error: Please provide valid Python code as a string."

    try:
        # Create a secure dict of globals with limited builtins.
        # `import builtins` is used instead of the implicit `__builtins__`,
        # which may be either a dict or a module depending on context.
        import builtins
        restricted_globals = {
            "__builtins__": {
                k: getattr(builtins, k) for k in [
                    'abs', 'all', 'any', 'bool', 'chr', 'dict', 'dir', 'divmod',
                    'enumerate', 'filter', 'float', 'format', 'frozenset', 'hash',
                    'hex', 'int', 'isinstance', 'len', 'list', 'map', 'max',
                    'min', 'oct', 'ord', 'pow', 'print', 'range', 'repr',
                    'round', 'set', 'slice', 'sorted', 'str', 'sum', 'tuple', 'type', 'zip'
                ] if hasattr(builtins, k)
            }
        }

        # Add common math functions
        import math
        for name in ['sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'sqrt',
                     'log', 'log10', 'exp', 'pi', 'e', 'ceil', 'floor', 'degrees', 'radians']:
            if hasattr(math, name):
                restricted_globals[name] = getattr(math, name)

        # Local namespace for variables
        local_vars = {}

        # Capture stdout
        stdout_capture = StringIO()

        # Execute the code
        with contextlib.redirect_stdout(stdout_capture):
            try:
                # Try to evaluate as an expression first
                result = eval(code, restricted_globals, local_vars)
                stdout_content = stdout_capture.getvalue().strip()

                if stdout_content:
                    return f"{stdout_content}\nResult: {result}"
                return f"Result: {result}"
            except SyntaxError:
                # Not an expression, try executing as statements
                exec(code, restricted_globals, local_vars)
                stdout_content = stdout_capture.getvalue().strip()

                if stdout_content:
                    return stdout_content
                return "Code executed successfully with no output."

    except Exception as e:
        return f"Error executing code: {type(e).__name__}: {str(e)}"

chain_of_thought.py ADDED
@@ -0,0 +1,34 @@
# NOTE: illustrative sketch only -- `Tool`, `Memory`, and `Agent` are not part of
# the released langgraph API, so these imports will not resolve as written.
from langgraph import Tool, Memory, Agent
from langgraph.tools import WebSearchTool, CalculatorTool

# Define tools the agent can use
tools = [
    WebSearchTool(name="web_search",
                  description="Useful for searching the web"),
    CalculatorTool(name="calculator",
                   description="Useful for arithmetic calculations")
]

# Set up simple memory (e.g., conversation history)
memory = Memory(max_tokens=500)

# Create the agent with reasoning and action capabilities
agent = Agent(
    name="SimpleReasoningAgent",
    tools=tools,
    memory=memory,
    reasoning_chain="explicit",  # use explicit chain-of-thought
    action_threshold=0.7  # confidence threshold to trigger actions
)

if __name__ == "__main__":
    print("Welcome to SimpleReasoningAgent! Type 'exit' to quit.")
    while True:
        user_input = input("\nUser: ")
        if user_input.lower() in ("exit", "quit"):
            print("Goodbye!")
            break

        # Agent processes input
        response = agent.run(user_input)
        print(f"Agent: {response}\n")

react_agent.py ADDED
@@ -0,0 +1,57 @@
from basic_tools import *

from langgraph.prebuilt import create_react_agent
from utils import *
from langchain_core.messages import SystemMessage, HumanMessage


# Initial system message
system_message = SystemMessage(content="You are a helpful assistant. You are free to use the available tools and give back a proper answer.")


def main(search_query: str = "What is the capital of France?") -> None:
    # Initialize the LLM (served locally by LM Studio)
    llm = get_llm(provider="openai_local")
    if llm:
        web_search_tools = [
            multiply, add, subtract, divide, modulus,
            wiki_search, web_search, arxiv_search,
            python_repl, analyze_image,
            date_filter, analyze_content,
            step_by_step_reasoning, translate_text
        ]
        # Create a langgraph ReAct agent with the LLM and tools.
        web_search_agent = create_react_agent(
            name="Web Search Agent",
            model=llm,
            prompt=system_message,  # the system prompt belongs in `prompt`, not in llm.bind()
            tools=web_search_tools,
            response_format={
                "title": "SearchResults",
                "description": "Structured JSON object with search results",
                "type": "object",
                "properties": {
                    "results": {
                        "type": "array",
                        "items": {"type": "string"}
                    }
                },
                "required": ["results"]
            }
        )

        # Provide a complete conversation history containing both a system and an initial
        # user message, so the agent starts from a valid first user turn. The payload must
        # be a dict with a "messages" list, not a bare list of messages:
        # input_payload = {
        #     "messages": [
        #         {"role": "system", "content": system_message.content},
        #         {"role": "user", "content": f"{search_query}"}
        #     ]
        # }
        input_payload = {"messages": [
            system_message, HumanMessage(content=f"{search_query}")]}
        results = web_search_agent.invoke(input_payload)
        print(results)


if __name__ == "__main__":
    main("can you find out what is the best place to visit in France")

reasoning_agent.py ADDED
@@ -0,0 +1,340 @@
"""
Simple Reasoning and Action Agent using LangGraph and LangChain

This agent follows a standard reasoning pattern:
1. Think - Analyze the input and determine an approach
2. Select - Choose appropriate tools from available options
3. Act - Use the selected tools
4. Observe - Review results
5. Conclude - Generate final response
"""

import os
from typing import Dict, List, Annotated, TypedDict, Union, Tuple, Any

from langchain_core.tools import BaseTool
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate

from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode


from basic_tools import *
from utils import *


def get_available_tools():
    tools = [
        multiply, add, subtract, divide, modulus,
        wiki_search, web_search, arxiv_search,
        python_repl, analyze_image,
        date_filter, analyze_content,
        step_by_step_reasoning, translate_text
    ]
    return tools


# Define the agent state
class AgentState(TypedDict):
    """State for the reasoning and action agent."""
    messages: List[Union[AIMessage, HumanMessage, SystemMessage, ToolMessage]]
    # We'll store intermediate steps of reasoning here
    reasoning: List[str]
    # Keep track of selected tools
    selected_tools: List[str]
    # Store tool results
    tool_results: Dict[str, Any]


# Shared LLM used by every node in the graph
model = get_llm(provider="openai")

# System prompts
AGENT_SYSTEM_PROMPT = """You are a helpful reasoning and action agent.
Your job is to:
1. Carefully analyze the user's request
2. Think through the problem step by step
3. Select appropriate tools from your toolkit
4. Use those tools to address the request
5. Provide a clear, complete response

Available tools:
{tool_descriptions}

When you need to use a tool, select the most appropriate one based on your reasoning.
Always show your reasoning process clearly.
"""


# ============= Node Functions =============

def think(state: AgentState) -> AgentState:
    """Think through the problem and analyze the user request."""

    # Extract the user's most recent message
    user_message = state["messages"][-1]
    if not isinstance(user_message, HumanMessage):
        # If the last message isn't from the user, find the most recent one
        for msg in reversed(state["messages"]):
            if isinstance(msg, HumanMessage):
                user_message = msg
                break

    # Create a prompt for thinking
    think_prompt = ChatPromptTemplate.from_messages([
        SystemMessage(
            content="You are analyzing a user request. Think step by step about what the user is asking for and what approach would be best."),
        ("user", "{input}")
    ])

    # Generate thinking output
    think_response = model.invoke(
        think_prompt.format_messages(input=user_message.content)
    )

    # Update state with reasoning
    reasoning = think_response.content
    state["reasoning"] = state.get("reasoning", []) + [reasoning]

    return state


def select_tools(state: AgentState) -> AgentState:
    """Select appropriate tools based on the reasoning."""

    # Get available tools
    tools = get_available_tools()
    tool_descriptions = "\n".join(
        [f"- {tool.name}: {tool.description}" for tool in tools])

    # Create a prompt for tool selection
    select_prompt = ChatPromptTemplate.from_messages([
        SystemMessage(content=f"""Based on your analysis, select which tools would be most helpful for this task.
Available tools:
{tool_descriptions}

Return your selection as a comma-separated list of tool names, e.g., "calculator,web_search".
Only include tools that are actually needed for this specific request."""),
        ("user", "{reasoning}")
    ])

    # Generate tool selection output
    select_response = model.invoke(
        select_prompt.format_messages(reasoning=state["reasoning"][-1])
    )

    # Parse the selected tools
    selected_tools = [
        tool_name.strip()
        for tool_name in select_response.content.split(',')
    ]

    # Filter to ensure only valid tools are selected
    valid_tool_names = [tool.name for tool in tools]
    selected_tools = [
        tool for tool in selected_tools if tool in valid_tool_names]

    # Update state with selected tools
    state["selected_tools"] = selected_tools

    # Add a single AIMessage with all tool calls (if any tools were selected).
    # NOTE: args are left empty here, so tools with required parameters will
    # need populated arguments before ToolNode can run them successfully.
    if selected_tools:
        tool_calls = [
            {"id": f"call_{i}", "name": tool_name, "args": {}}
            for i, tool_name in enumerate(selected_tools)
        ]
        state["messages"].append(
            AIMessage(
                content="",
                tool_calls=tool_calls
            )
        )

    return state


# def execute_tools(state: AgentState) -> AgentState:
#     """Execute the selected tools."""
#
#     # Get all available tools
#     all_tools = get_available_tools()
#
#     # Filter to only use selected tools
#     selected_tool_names = state["selected_tools"]
#     tools_to_use = [
#         tool for tool in all_tools if tool.name in selected_tool_names]
#
#     # Create tool executor
#     tool_executor = ToolExecutor(tools_to_use)
#
#     # Get the most recent reasoning
#     reasoning = state["reasoning"][-1]
#
#     # For each tool, generate a specific input and execute
#     tool_results = {}
#     for tool in tools_to_use:
#         # Create prompt for generating tool input
#         tool_input_prompt = ChatPromptTemplate.from_messages([
#             SystemMessage(content=f"""Generate a specific input for the following tool:
#             Tool: {tool.name}
#             Description: {tool.description}
#
#             The input should be formatted according to the tool's requirements and contain all necessary information.
#             Return only the exact input string that should be passed to the tool, nothing else."""),
#             ("user", "{reasoning}")
#         ])
#
#         # Generate specific input for this tool
#         tool_input_response = model.invoke(
#             tool_input_prompt.format_messages(reasoning=reasoning)
#         )
#         tool_input = tool_input_response.content.strip()
#
#         try:
#             # Execute the tool with the generated input
#             result = tool_executor.invoke({tool.name: tool_input})
#             tool_results[tool.name] = result[tool.name]
#
#             # Add tool message to conversation
#             state["messages"].append(
#                 ToolMessage(content=str(result[tool.name]), name=tool.name)
#             )
#         except Exception as e:
#             # Handle errors
#             tool_results[tool.name] = f"Error executing tool: {str(e)}"
#             state["messages"].append(
#                 ToolMessage(
#                     content=f"Error executing tool: {str(e)}", name=tool.name)
#             )
#
#     # Update state with tool results
#     state["tool_results"] = tool_results
#
#     return state


def generate_response(state: AgentState) -> AgentState:
    """Generate a final response based on reasoning and tool outputs."""

    # Prepare the context for response generation
    tool_outputs = "\n".join([
        f"{tool_name}: {result}"
        for tool_name, result in state.get("tool_results", {}).items()
    ])

    # Create prompt for response generation
    response_prompt = ChatPromptTemplate.from_messages([
        SystemMessage(content="""Generate a helpful response to the user based on your reasoning and tool outputs.
Be thorough but concise. Focus on directly answering the user's request.
If tools provided relevant information, incorporate it into your response."""),
        ("user",
         "User request: {user_request}\n\nReasoning: {reasoning}\n\nTool outputs: {tool_outputs}")
    ])

    # Get original user request
    user_request = None
    for msg in reversed(state["messages"]):
        if isinstance(msg, HumanMessage):
            user_request = msg.content
            break

    # Generate final response
    response = model.invoke(
        response_prompt.format_messages(
            user_request=user_request,
            reasoning=state["reasoning"][-1],
            tool_outputs=tool_outputs
        )
    )

    # Add the AI response to messages
    state["messages"].append(AIMessage(content=response.content))

    return state


# ============= Graph Definition =============

def create_agent_graph():
    """Create and configure the agent graph."""

    graph = StateGraph(AgentState)

    graph.add_node("think", think)
    graph.add_node("select_tools", select_tools)

    tools = get_available_tools()
    tool_node = ToolNode(tools)
    graph.add_node("execute_tools", tool_node)

    graph.add_node("generate_response", generate_response)

    # Conditional edge: if no tools were selected, skip execute_tools
    def select_tools_next(state: AgentState):
        if state["selected_tools"]:
            return "execute_tools"
        else:
            return "generate_response"

    graph.add_edge("think", "select_tools")
    graph.add_conditional_edges("select_tools", select_tools_next)
    graph.add_edge("execute_tools", "generate_response")
    graph.add_edge("generate_response", END)

    graph.set_entry_point("think")
    return graph.compile()


# ============= Agent Interface =============

class ReasoningAgent:
    """Reasoning and action agent main class."""

    def __init__(self):
        self.graph = create_agent_graph()
        # Initialize with system prompt
        tools = get_available_tools()
        tool_descriptions = "\n".join(
            [f"- {tool.name}: {tool.description}" for tool in tools])
        self.messages = [
            SystemMessage(content=AGENT_SYSTEM_PROMPT.format(
                tool_descriptions=tool_descriptions))
        ]

    def invoke(self, user_input: str) -> str:
        """Process user input and return the response."""
        # Add user message to history
        self.messages.append(HumanMessage(content=user_input))

        # Initialize state
        state = {"messages": self.messages, "reasoning": [],
                 "selected_tools": [], "tool_results": {}}

        # Run the graph
        result = self.graph.invoke(state)

        # Update messages
        self.messages = result["messages"]

        # Return the last AI message
        for msg in reversed(result["messages"]):
            if isinstance(msg, AIMessage):
                return msg.content

        # Fallback
        return "I encountered an issue processing your request."

    def __call__(self, *args, **kwargs):
        """Invoke the agent with user input."""
        return self.invoke(*args, **kwargs)


# Sample usage
if __name__ == "__main__":
    agent = ReasoningAgent()
    response = agent.invoke(
        "What's the weather in New York today and should I take an umbrella?")
    print(response)
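
Because __call__ simply forwards to invoke, the compiled agent can also be driven like a function; a small interactive sketch:

from reasoning_agent import ReasoningAgent

agent = ReasoningAgent()
while True:
    question = input("You: ")
    if question.lower() in ("exit", "quit"):
        break
    print("Agent:", agent(question))  # __call__ delegates to invoke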

system_prompt.txt CHANGED
@@ -1,17 +1,43 @@
You are an expert research and analysis assistant with access to specialized tools. Follow these instructions precisely:

STEP 1: ANALYZE THE QUESTION CAREFULLY
Before selecting tools, understand exactly what information is needed.

STEP 2: SELECT THE APPROPRIATE TOOLS
Choose tools based on what information you need:

SEARCH TOOLS:
- wiki_search: Get encyclopedia facts using specific queries
  Example: {"name": "wiki_search", "parameters": {"query": "Mercedes Sosa discography"}}
- web_search: For current information and detailed explanations
- arxiv_search: For academic papers and research

ANALYSIS TOOLS:
- analyze_discography: Find albums by an artist in a specific time period
  Example: {"name": "analyze_discography", "parameters": {"content": "...", "artist_name": "Mercedes Sosa", "start_year": 2000, "end_year": 2009}}
- date_filter: Extract content only from a specific time period
- analyze_content: Extract specific types of information (dates, names, numbers, events)
- step_by_step_reasoning: Break down complex problems into logical steps

MEDIA TOOLS:
- youtube_transcript: Extract and optionally summarize video content
  Example: {"name": "youtube_transcript", "parameters": {"url": "https://www.youtube.com/watch?v=abc123", "summarize": true}}

LANGUAGE TOOLS:
- translate_text: Translate content to another language

MATH TOOLS:
- add, subtract, multiply, divide, modulus: For calculations

REQUEST TOOLS:
- requests_get: Make HTTP GET requests to external APIs
- requests_post: Make HTTP POST requests to external APIs

STEP 3: USE TOOLS WITH ALL REQUIRED PARAMETERS
Every tool requires specific parameters - never call a tool without all required parameters.

STEP 4: PROVIDE YOUR FINAL ANSWER
After gathering information with tools, provide your answer:
FINAL ANSWER: [Your concise, factual answer]

Remember: If you're asked about albums, songs, or artists in specific time periods, use wiki_search first, then analyze_discography with appropriate date parameters.

tool_calling_agent.py ADDED
@@ -0,0 +1,16 @@
from smolagents import ToolCallingAgent
from utils import *
from basic_tools import *
from smolagents.tools import Tool

langchain_tools = [
    multiply, add, subtract, divide, modulus,
    wiki_search, web_search, arxiv_search,
    python_repl, analyze_image,
    date_filter, analyze_content,
    step_by_step_reasoning, translate_text
]
# Wrap each langchain tool in the smolagents Tool interface
tools = [Tool.from_langchain(tool) for tool in langchain_tools]
agent = ToolCallingAgent(
    model=get_llm(),  # NOTE: smolagents normally expects one of its own model
                      # wrappers, not a langchain chat model; an adapter may be needed.
    tools=tools,
)
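
A usage sketch, assuming the standard smolagents interface where agents are driven with run():

from tool_calling_agent import agent

answer = agent.run("What is 128 divided by 8?")
print(answer)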

utils.py ADDED
@@ -0,0 +1,57 @@
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace, HuggingFaceEmbeddings
import torch
import os
from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq
from langchain.chat_models.base import BaseChatModel
from langchain_chroma import Chroma


def get_llm(provider: str = "groq") -> BaseChatModel:
    if provider == "groq":
        # Groq models: https://console.groq.com/docs/models
        # alternatives: qwen-qwq-32b, gemma2-9b-it
        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
    elif provider == "huggingface":
        # TODO: choose a production Hugging Face endpoint
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                model="Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            ),
        )
    elif provider == "openai_local":
        llm = ChatOpenAI(
            base_url="http://localhost:11432/v1",  # local LM Studio endpoint
            api_key="not-used",  # required by the interface but ignored  # type: ignore
            model="mistral-nemo-instruct-2407",
            temperature=0.2
        )
    elif provider == "openai":
        llm = ChatOpenAI(
            model="gpt-4o",
            temperature=0.2,
        )
    else:
        raise ValueError(
            "Invalid provider. Choose 'groq', 'huggingface', 'openai_local' or 'openai'.")
    return llm


embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",
    # torch device strings are "cuda"/"cpu", not "gpu"
    model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu",
                  "token": os.getenv("HF_TOKEN")},
    show_progress=True,
)

# Initialize an empty Chroma vector store
vector_store = Chroma(
    embedding_function=embeddings,
    collection_name="agent_memory"
)
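
A short usage sketch tying the helpers together (assumes GROQ_API_KEY is set for the default provider; the method names follow the langchain_groq and langchain_chroma APIs):

from utils import get_llm, vector_store

llm = get_llm(provider="groq")
print(llm.invoke("Say hello in one word.").content)

# The shared Chroma collection doubles as simple agent memory
vector_store.add_texts(["The user's favorite city is Paris."])
hits = vector_store.similarity_search("favorite city", k=1)
print(hits[0].page_content)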