File size: 5,810 Bytes
e9514ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
from typing import Optional
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph import START, StateGraph, MessagesState
from langchain_openai import ChatOpenAI
from langchain_openai import ChatOpenAI
# from langchain_community.utilities import GoogleSerperAPIWrapper
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_community.tools.tavily_search.tool import TavilySearchResults
from langchain_community.tools import WikipediaQueryRun
from langchain_core.tools import tool
from langchain.tools import Tool
import requests
import os
from pathlib import Path
import tempfile
# System message prepended to every model call made in ``build_graph``. It asks
# the model to finish with a "FINAL ANSWER: ..." line in a constrained format.
system_prompt = """
You are a general AI assistant. I will ask you a question. If necessary, use the tools to seek the correct answer. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
"""

# Search backends; they are wrapped as agent tools in ``tools`` further down.
search = TavilySearchResults(max_results=5)  # web search, at most 5 results
_wikipedia_api = WikipediaAPIWrapper()
wikipedia = WikipediaQueryRun(api_wrapper=_wikipedia_api)
# Custom tools

# Timeout (seconds) for the image request. requests applies it to connecting
# and to each socket read, so a stalled server cannot block the agent forever.
_DOWNLOAD_TIMEOUT_SECONDS = 30

# (Content-Type substring, file extension) pairs, checked in order.
_IMAGE_EXTENSIONS = (
    ("png", ".png"),
    ("jpeg", ".jpg"),
    ("jpg", ".jpg"),
    ("gif", ".gif"),
    ("webp", ".webp"),
)


def _image_extension(content_type: str) -> str:
    """Return the file extension for a Content-Type header value, defaulting to ``.jpg``."""
    lowered = content_type.lower()
    for marker, ext in _IMAGE_EXTENSIONS:
        if marker in lowered:
            return ext
    return ".jpg"


def _image_filename(url: str, filename: Optional[str], ext: str) -> str:
    """Choose a bare file name (no directory part) for a downloaded image.

    Tries *filename*, then the last segment of the URL path, then a
    timestamp-based name. *ext* is appended when the chosen name has no suffix.
    """
    import time
    from urllib.parse import urlparse

    # Only the final path component is kept: the name comes from the model,
    # and values like "/etc/x" or "../x" must not escape the temp directory.
    # urlparse drops the query string and the #fragment from the URL path.
    for candidate in (filename or "", urlparse(url).path):
        name = Path(candidate).name
        if name not in ("", ".."):
            break
    else:
        name = f"image_{int(time.time())}"
    if not Path(name).suffix:
        name = f"{name}{ext}"
    return name


@tool
def download_image(url: str, filename: Optional[str] = None) -> str:
    """Download an image from a URL and save it to a file.

    Args:
        url: The URL of the image to download.
        filename: Optional; The name to save the file as. If not provided, a name will be derived from the URL.

    Returns:
        A message with the path the image was saved to, or an error message.
    """
    # Failures are reported back to the model as text rather than raised, so
    # one bad download does not abort the whole agent run.
    try:
        # The context manager closes the streamed connection even if writing fails.
        with requests.get(url, stream=True, timeout=_DOWNLOAD_TIMEOUT_SECONDS) as response:
            response.raise_for_status()  # raise for 4xx/5xx
            ext = _image_extension(response.headers.get("Content-Type", ""))
            name = _image_filename(url, filename, ext)
            filepath = os.path.join(tempfile.gettempdir(), name)
            with open(filepath, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        return f"Image downloaded and saved to {filepath}"
    except Exception as e:
        return f"Error downloading image: {str(e)}"
@tool
def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
    """
    Save content to a file and return the path.

    Args:
        content (str): the content to save to the file
        filename (str, optional): the name of the file. If not provided, a random name file will be created.
    """
    temp_dir = tempfile.gettempdir()
    # Only the final path component is kept: the name comes from the model and
    # must not point outside the temp directory (e.g. "/etc/x" or "../x").
    # An empty name, or one that reduces to nothing, gets a random name instead.
    safe_name = Path(filename).name if filename else ""
    if safe_name in ("", ".."):
        # mkstemp creates the file and returns an open descriptor. Wrapping the
        # descriptor ensures it is closed after writing instead of leaking.
        fd, filepath = tempfile.mkstemp(dir=temp_dir)
        handle = os.fdopen(fd, "w", encoding="utf-8")
    else:
        filepath = os.path.join(temp_dir, safe_name)
        handle = open(filepath, "w", encoding="utf-8")
    # Explicit UTF-8 so non-ASCII content does not depend on the platform locale.
    with handle:
        handle.write(content)
    return f"File saved to {filepath}. You can read this file to process its contents."
# (name, description, callable) for the search backends. They are exposed to
# the model as plain ``Tool`` wrappers under short, fixed names.
_WRAPPED_TOOL_SPECS = (
    ("web_search", "useful when you need to ask with search", search.run),
    ("wikipedia", "useful when you need wikipedia to answer questions.", wikipedia.run),
)

# Every tool bound to the model and executed by the graph's ToolNode:
# the two custom file tools first, then the wrapped search backends.
tools = [save_and_read_file, download_image] + [
    Tool(name=tool_name, description=tool_description, func=tool_func)
    for tool_name, tool_description, tool_func in _WRAPPED_TOOL_SPECS
]
def build_graph(model: str = "gpt-4o", temperature: float = 0):
    """Build and compile the tool-calling agent graph.

    The graph alternates between an ``assistant`` node and a ``tools`` node.
    The ``assistant`` node is the chat model with ``tools`` bound. The
    ``tools`` node is a ``ToolNode`` that executes the model's tool calls.
    Routing out of ``assistant`` is decided by ``tools_condition``.

    Args:
        model: OpenAI chat model name passed to ``ChatOpenAI``.
        temperature: Sampling temperature passed to ``ChatOpenAI``.

    Returns:
        The compiled graph; invoke it with ``{"messages": [...]}``.
    """
    system_message = SystemMessage(content=system_prompt)
    llm_with_tools = ChatOpenAI(model=model, temperature=temperature).bind_tools(tools)

    def assistant(state: MessagesState):
        # The system prompt is prepended per call and never written to state,
        # so it does not pile up in the stored conversation.
        response = llm_with_tools.invoke([system_message] + state["messages"])
        # MessagesState merges node output through its add_messages reducer,
        # so returning just the new message appends it to the history.
        return {"messages": [response]}

    builder = StateGraph(MessagesState)
    # Nodes
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    # Edges: start at the assistant; tools_condition routes to "tools" when the
    # model requested tool calls, and each tool result goes back to the assistant.
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()
if __name__ == "__main__":
    # Manual smoke run: ask the agent one question and print every message in
    # the final state.
    demo_question = "Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2016?"
    agent = build_graph()
    final_state = agent.invoke({"messages": [HumanMessage(content=demo_question)]})
    for message in final_state["messages"]:
        message.pretty_print()
|