File size: 5,810 Bytes
e9514ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from typing import Optional
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph import START, StateGraph, MessagesState
from langchain_openai import ChatOpenAI
from langchain_openai import ChatOpenAI
# from langchain_community.utilities import GoogleSerperAPIWrapper
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_community.tools.tavily_search.tool import TavilySearchResults
from langchain_community.tools import WikipediaQueryRun
from langchain_core.tools import tool
from langchain.tools import Tool

import requests
import os
from pathlib import Path
import tempfile

# System prompt that build_graph() prepends to every model call. It asks the
# model to reason, then end with a "FINAL ANSWER: ..." line in a tightly
# constrained format (a bare number, as few words as possible, or a
# comma-separated list).
# NOTE(review): this string is sent to the model verbatim, including the
# leading newline and 4-space indentation; any wording change alters agent
# behaviour. It appears to follow a benchmark answer-format prompt — confirm
# before editing in case downstream answer parsing depends on it.
system_prompt = """
    You are a general AI assistant. I will ask you a question. If necessary, use the tools to seek the correct answer. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
"""

# Initialize tools.
# These clients are constructed at import time, so any setup failure in them
# surfaces as soon as this module is imported.
# Alternative search backend, kept for reference:
# search = GoogleSerperAPIWrapper()
# Tavily web search, returning at most five results per query.
search = TavilySearchResults(max_results=5)
# Wikipedia lookup tool using the default WikipediaAPIWrapper configuration.
wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())

# Custom tools
@tool
def download_image(url: str, filename: Optional[str] = None) -> str:
    """Download an image from a URL and save it to a file.
    Args:
        url: The URL of the image to download.
        filename: Optional; The name to save the file as. If not provided, a name will be derived from the URL.
    """
    # The docstring above is what @tool sends to the model as this tool's
    # description, so it stays short and model-facing; implementation notes
    # live in these comments instead.
    # The tool always returns a status string (the saved path or an error
    # message) rather than raising, so the model can see and react to failures.
    import time
    from urllib.parse import unquote, urlsplit

    # Content-Type substring -> file extension, checked in order (first match
    # wins); anything unrecognised falls back to ".jpg".
    extension_by_type = (
        ("png", ".png"),
        ("jpeg", ".jpg"),
        ("jpg", ".jpg"),
        ("gif", ".gif"),
        ("webp", ".webp"),
    )

    try:
        # The timeout bounds the connect and each socket read so a stalled
        # server cannot hang the agent; `with` releases the streamed
        # connection on every exit path.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()  # Raise an exception for HTTP errors

            content_type = response.headers.get('Content-Type', '')
            ext = next(
                (suffix for key, suffix in extension_by_type if key in content_type),
                '.jpg',
            )

            if filename:
                # Model-supplied name: keep only the final path component so
                # values such as "../x.png" or "/etc/x" cannot write outside
                # the temp directory.
                name = Path(filename).name
            else:
                # Last segment of the URL path, percent-decoded; urlsplit
                # already excludes the query string and fragment.
                name = Path(unquote(urlsplit(url).path)).name
            if not name:
                name = f"image_{int(time.time())}"
            # Ensure the file has an extension.
            if not Path(name).suffix:
                name = f"{name}{ext}"

            filepath = os.path.join(tempfile.gettempdir(), name)
            with open(filepath, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

        return f"Image downloaded and saved to {filepath}"
    except Exception as e:
        # Deliberate catch-all at the tool boundary: the error text goes back
        # to the model instead of aborting the graph run.
        return f"Error downloading image: {str(e)}"
    

@tool
def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
    """
    Save content to a file and return the path.
    Args:
        content (str): the content to save to the file
        filename (str, optional): the name of the file. If not provided, a random name file will be created.
    """
    # The docstring above doubles as the model-facing tool description (via
    # @tool), so it stays short. Files always land in the system temp directory.
    temp_dir = tempfile.gettempdir()

    # Model-supplied names are reduced to their final path component so they
    # cannot point outside temp_dir. Names with no usable component ("", ".",
    # "..") fall back to a random file, like the no-name case, instead of
    # trying to open a directory.
    safe_name = Path(filename).name if filename else ""
    if safe_name in ("", ".", ".."):
        # mkstemp creates the file and returns an open descriptor; close it
        # immediately so it is not leaked, then reopen by path below.
        fd, filepath = tempfile.mkstemp(dir=temp_dir)
        os.close(fd)
    else:
        filepath = os.path.join(temp_dir, safe_name)

    # Explicit UTF-8 so non-ASCII content does not depend on the platform's
    # default encoding.
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(content)

    return f"File saved to {filepath}. You can read this file to process its contents."

# Every tool the model may call, in the order they are bound: the two @tool
# functions defined above, followed by the search and Wikipedia clients wrapped
# as plain `Tool`s under the names the model sees.
tools = [save_and_read_file, download_image] + [
    Tool(
        name="web_search",
        func=search.run,
        description="useful when you need to ask with search",
    ),
    Tool(
        name="wikipedia",
        func=wikipedia.run,
        description="useful when you need wikipedia to answer questions.",
    ),
]


def build_graph(*, model: str = "gpt-4o", temperature: float = 0):
    """Build and compile the tool-calling agent graph.

    The graph alternates between an ``assistant`` node, which calls the chat
    model with ``tools`` bound, and a ``tools`` node (a prebuilt ``ToolNode``)
    that executes the requested tool calls and hands control back to the
    assistant. After each assistant turn, LangGraph's prebuilt
    ``tools_condition`` picks the next step: ``tools`` when the reply contains
    tool calls, otherwise the end of the run.

    Args:
        model: OpenAI chat model name passed to ``ChatOpenAI``.
        temperature: Sampling temperature passed to ``ChatOpenAI``.

    Returns:
        The compiled graph. Invoke it with ``{"messages": [HumanMessage(...)]}``;
        the result's ``"messages"`` holds the full conversation.
    """
    system_message = SystemMessage(content=system_prompt)
    llm_with_tools = ChatOpenAI(model=model, temperature=temperature).bind_tools(tools)

    def assistant(state: MessagesState):
        # The system prompt is prepended per call rather than stored in state,
        # so it never appears in the returned conversation.
        response = llm_with_tools.invoke([system_message] + state["messages"])
        # Return only the new message: MessagesState's `add_messages` reducer
        # appends it to the existing history. Echoing the whole history back
        # would make the reducer re-merge every prior message by id each turn.
        return {"messages": [response]}

    builder = StateGraph(MessagesState)
    # Nodes
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    # Edges
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()

if __name__ == "__main__":
    # Smoke-run: ask the agent one sample question and print every message in
    # the resulting conversation (user turn, tool calls/results, final reply).
    question = "Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2016?"
    agent = build_graph()
    initial_state = {"messages": [HumanMessage(content=question)]}
    final_state = agent.invoke(initial_state)
    for message in final_state["messages"]:
        message.pretty_print()