vonliechti committed on
Commit 60d9d3a · verified · 1 Parent(s): 1f1b1c4

Upload folder using huggingface_hub

Files changed (14)
  1. .github/workflows/update_space.yml +28 -0
  2. .gitignore +169 -0
  3. README.md +59 -10
  4. app.py +88 -0
  5. bots.py +70 -0
  6. data.py +80 -0
  7. prompts.py +108 -0
  8. run.py +30 -0
  9. test_bots.py +14 -0
  10. tools/squad_retriever.py +30 -0
  11. tools/text_to_image.py +13 -0
  12. tools/visual_qa.py +191 -0
  13. tools/web_surfer.py +205 -0
  14. utils.py +67 -0
.github/workflows/update_space.yml ADDED
@@ -0,0 +1,28 @@
+ name: Run Python script
+
+ on:
+   push:
+     branches:
+       - main
+
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+
+     steps:
+       - name: Checkout
+         uses: actions/checkout@v2
+
+       - name: Set up Python
+         uses: actions/setup-python@v2
+         with:
+           python-version: '3.9'
+
+       - name: Install Gradio
+         run: python -m pip install gradio
+
+       - name: Log in to Hugging Face
+         run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
+
+       - name: Deploy to Spaces
+         run: gradio deploy
.gitignore ADDED
@@ -0,0 +1,169 @@
+ # MacOS
+ .DS_Store
+
+ # Data
+ chroma_db/
+ data/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
README.md CHANGED
@@ -1,14 +1,63 @@
  ---
- title: SQuAD Agent Experiment
- emoji: 👁
- colorFrom: gray
- colorTo: pink
- sdk: gradio
- sdk_version: 5.0.1
+ title: SQuAD_Agent_Experiment
  app_file: app.py
- pinned: false
- license: apache-2.0
- short_description: SQuAD Question Answering Agent
+ sdk: gradio
+ sdk_version: 4.44.0
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # SQuAD_Agent_Experiment
+
+ ## Overview
+
+ The project is built with Transformers Agents 2.0 on top of the Stanford SQuAD dataset. The chatbot is designed to answer questions about the dataset, while incorporating conversational context and a set of tools to provide a more natural and engaging conversational experience.
+
+ ## Getting Started
+
+ 1. Install dependencies:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 2. Set up the required keys in a `.env` file:
+
+ ```bash
+ HUGGINGFACE_API_TOKEN=<your token>
+ ```
+
+ 3. Run the app:
+
+ ```bash
+ python app.py
+ ```
+
+ ## Methods Used
+
+ 1. SQuAD dataset: the chatbot is built on the Stanford SQuAD dataset, which contains over 100,000 questions and answers extracted from 500+ articles.
+ 2. RAG: retrieval-augmented generation improves a chatbot's accuracy by grounding it in a custom knowledge base; here, the SQuAD dataset serves as that knowledge base.
+ 3. Llama 3.1: a large language model used to generate responses to user questions while incorporating conversational context.
+ 4. Transformers Agents 2.0: the framework used to build the conversational agent.
+ 5. Custom tools: a `SquadRetrieverTool` integrates a fine-tuned BERT model into the agent, and a `TextToImageTool` adds a playful way to engage with the question-answering agent.
+
+ ## Evaluation
+
+ * [Agent Reasoning Benchmark](https://github.com/aymeric-roucher/agent_reasoning_benchmark)
+ * [Hugging Face Blog: Open Source LLMs as Agents](https://huggingface.co/blog/open-source-llms-as-agents)
+ * [Benchmarking Transformers Agents](https://github.com/aymeric-roucher/agent_reasoning_benchmark/blob/main/benchmark_transformers_agents.ipynb)
+
+ ## Results
+
+ TBD
+
+ ## Limitations
+
+ TBD
+
+ ## Future Work
+
+ TBD
+
+ ## Acknowledgments
+
+ * [MemGPT](https://github.com/cpacker/MemGPT)
+ * [Stanford SQuAD](https://rajpurkar.github.io/SQuAD-explorer/)
+ * [GPT-4](https://openai.com/gpt-4/)
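
For orientation, here is a minimal sketch of how the pieces described under "Methods Used" fit together; it mirrors the `app.py` added in this commit (model name, system prompt, and tools are taken from there), minus the Gradio UI:

```python
# Sketch: wire the custom tools into a ReAct code agent (mirrors app.py).
from transformers import ReactCodeAgent, HfApiEngine

from prompts import SQUAD_REACT_CODE_SYSTEM_PROMPT
from tools.squad_retriever import SquadRetrieverTool
from tools.text_to_image import TextToImageTool

# The RAG tool and the image tool form the agent's toolbox.
toolbox = [SquadRetrieverTool(), TextToImageTool()]

# Llama 3.1 served through the Hugging Face Inference API.
llm_engine = HfApiEngine("meta-llama/Meta-Llama-3.1-8B-Instruct")

agent = ReactCodeAgent(
    tools=toolbox,
    llm_engine=llm_engine,
    system_prompt=SQUAD_REACT_CODE_SYSTEM_PROMPT,
)

# One-shot question; the agent decides when to call squad_retriever.
print(agent.run("What is on top of the Notre Dame building?"))
```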
app.py ADDED
@@ -0,0 +1,88 @@
+ import gradio as gr
+ from gradio import ChatMessage
+ from transformers import ReactCodeAgent, HfApiEngine
+ from utils import stream_from_transformers_agent
+ from prompts import SQUAD_REACT_CODE_SYSTEM_PROMPT
+ from tools.squad_retriever import SquadRetrieverTool
+ from tools.text_to_image import TextToImageTool
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ TASK_SOLVING_TOOLBOX = [
+     SquadRetrieverTool(),
+     TextToImageTool(),
+ ]
+
+ model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+ # model_name = "http://localhost:1234/v1"  # e.g. a local OpenAI-compatible endpoint
+
+ llm_engine = HfApiEngine(model_name)
+
+ # Initialize the agent with both tools
+ agent = ReactCodeAgent(
+     tools=TASK_SOLVING_TOOLBOX,
+     llm_engine=llm_engine,
+     system_prompt=SQUAD_REACT_CODE_SYSTEM_PROMPT,
+ )
+
+ def append_example_message(x: gr.SelectData, messages):
+     if x.value["text"] is not None:
+         message = x.value["text"]
+     if "files" in x.value:
+         if isinstance(x.value["files"], list):
+             message = "Here are the files: "
+             for file in x.value["files"]:
+                 message += f"{file}, "
+         else:
+             message = x.value["files"]
+     messages.append(ChatMessage(role="user", content=message))
+     return messages
+
+ def add_message(message, messages):
+     messages.append(ChatMessage(role="user", content=message))
+     return messages
+
+ def interact_with_agent(messages):
+     prompt = messages[-1]['content']
+     for msg in stream_from_transformers_agent(agent, prompt):
+         messages.append(msg)
+         yield messages
+     yield messages
+
+ with gr.Blocks(fill_height=True) as demo:
+     chatbot = gr.Chatbot(
+         label="SQuAD Agent",
+         type="messages",
+         avatar_images=(
+             None,
+             "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
+         ),
+         scale=1,
+         bubble_full_width=False,
+         autoscroll=True,
+         show_copy_all_button=True,
+         show_copy_button=True,
+         placeholder="Enter a message",
+         examples=[
+             {
+                 "text": "What is on top of the Notre Dame building?",
+             },
+             {
+                 "text": "Tell me what's on top of the Notre Dame building, and draw a picture of it.",
+             },
+             {
+                 "text": "Draw a picture of whatever is on top of the Notre Dame building.",
+             },
+         ],
+     )
+     text_input = gr.Textbox(lines=1, label="Chat Message", scale=0)
+     chat_msg = text_input.submit(add_message, [text_input, chatbot], [chatbot])
+     bot_msg = chat_msg.then(interact_with_agent, [chatbot], [chatbot])
+     text_input.submit(lambda: "", None, text_input)
+     chatbot.example_select(append_example_message, [chatbot], [chatbot]).then(
+         interact_with_agent, [chatbot], [chatbot]
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
bots.py ADDED
@@ -0,0 +1,70 @@
+ from data import Data
+
+ class BotWrapper:
+     """Adapter so that different kinds of bots can be driven through the same
+     chat()/stream_chat() interface. Used by the Bots class to hand a uniform
+     list of bots to the frontend."""
+
+     def __init__(self, bot):
+         self.bot = bot
+
+     def chat(self, *args, **kwargs):
+         methods = ['chat', 'query']
+         for method in methods:
+             if hasattr(self.bot, method):
+                 print(f"Calling {method} method")
+                 method_to_call = getattr(self.bot, method)
+                 # .response is the answer text on both llama_index response types
+                 return method_to_call(*args, **kwargs).response
+         raise AttributeError(f"'{self.bot.__class__.__name__}' object has none of the required methods: '{methods}'")
+
+     def stream_chat(self, *args, **kwargs):
+         methods = ['stream_chat', 'query']
+         for method in methods:
+             if hasattr(self.bot, method):
+                 print(f"Calling {method} method")
+                 method_to_call = getattr(self.bot, method)
+                 return method_to_call(*args, **kwargs).response_gen
+         raise AttributeError(f"'{self.bot.__class__.__name__}' object has none of the required methods: '{methods}'")
+
+ class Bots:
+     """Creates the bots and passes them to the frontend."""
+
+     def __init__(self):
+         self.data = Data()
+         self.data.load_data()
+         self.query_engine = None
+         self.chat_agent = None
+         self.all_bots = None
+         self.create_bots()
+
+     def create_query_engine_bot(self):
+         if self.query_engine is None:
+             self.query_engine = BotWrapper(self.data.index.as_query_engine())
+         return self.query_engine
+
+     def create_chat_agent(self):
+         if self.chat_agent is None:
+             from llama_index.core.memory import ChatMemoryBuffer
+             memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
+             self.chat_agent = BotWrapper(self.data.index.as_chat_engine(
+                 chat_mode="context",
+                 memory=memory,
+                 context_prompt=(
+                     "You are a chatbot, able to have normal interactions, as well as talk"
+                     " about the questions and answers you know about."
+                     " Here are the relevant documents for the context:\n"
+                     "{context_str}"
+                     "\nInstruction: Use the previous chat history, or the context above, to interact and help the user."
+                 )
+             ))
+         return self.chat_agent
+
+     def create_bots(self):
+         self.create_query_engine_bot()
+         self.create_chat_agent()
+         self.all_bots = [self.query_engine, self.chat_agent]
+         return self.all_bots
+
+     def get_bots(self):
+         return self.all_bots
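
A short usage sketch for the classes above; it assumes the SQuAD data and Chroma index described in `data.py` are in place, and that the wrapped engines expose llama_index's usual `.response` / `.response_gen` fields (the questions are illustrative):

```python
# Sketch: both wrapped bots expose the same chat()/stream_chat() surface.
from bots import Bots

bots = Bots()
query_bot, chat_bot = bots.get_bots()

# The query engine answers a single question...
print(query_bot.chat("What is on top of the Notre Dame building?"))

# ...while the chat agent keeps conversational memory and can stream tokens.
for token in chat_bot.stream_chat("And who designed it?"):
    print(token, end="")
```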
data.py ADDED
@@ -0,0 +1,80 @@
+ import os
+ import json
+ import chromadb
+ from llama_index.core import VectorStoreIndex
+ from llama_index.vector_stores.chroma import ChromaVectorStore
+ from llama_index.core import StorageContext
+ from llama_index.core import Document
+
+ from dotenv import load_dotenv
+
+ load_dotenv()  # Load OPENAI_API_KEY from .env (not included in repo)
+
+ class Data:
+     def __init__(self):
+         self.client = None
+         self.collection = None
+         self.index = None
+         self.load_data()
+
+     def load_data(self):
+         print("Loading data...")
+         with open('data/train-v1.1.json', 'r') as f:
+             raw_data = json.load(f)
+
+         extracted_question = []
+         extracted_answer = []
+
+         for data in raw_data['data']:
+             for par in data['paragraphs']:
+                 for qa in par['qas']:
+                     for ans in qa['answers']:
+                         extracted_question.append(qa['question'])
+                         extracted_answer.append(ans['text'])
+
+         documents = []
+         for i in range(len(extracted_question)):
+             documents.append(f"Question: {extracted_question[i]} \nAnswer: {extracted_answer[i]}")
+
+         self.documents = [Document(text=t) for t in documents]
+         self.extracted_question = extracted_question
+         self.extracted_answer = extracted_answer
+
+         print("Raw data loaded")
+
+         if not os.path.exists("./chroma_db"):
+             print("Creating Chroma DB...")
+             # initialize client, setting path to save data
+             self.client = chromadb.PersistentClient(path="./chroma_db")
+
+             # create collection
+             self.collection = self.client.get_or_create_collection("simple_index")
+
+             # assign chroma as the vector_store to the context
+             vector_store = ChromaVectorStore(chroma_collection=self.collection)
+             storage_context = StorageContext.from_defaults(vector_store=vector_store)
+
+             # create your index
+             self.index = VectorStoreIndex.from_documents(
+                 self.documents, storage_context=storage_context
+             )
+             print("Chroma DB created")
+         else:
+             print("Chroma DB already exists")
+
+         print("Loading index...")
+         # initialize client
+         self.client = chromadb.PersistentClient(path="./chroma_db")
+
+         # get collection
+         self.collection = self.client.get_or_create_collection("simple_index")
+
+         # assign chroma as the vector_store to the context
+         vector_store = ChromaVectorStore(chroma_collection=self.collection)
+         storage_context = StorageContext.from_defaults(vector_store=vector_store)
+
+         # load your index from stored vectors
+         self.index = VectorStoreIndex.from_vector_store(
+             vector_store, storage_context=storage_context
+         )
+         print("Index loaded")
prompts.py ADDED
@@ -0,0 +1,108 @@
+ SQUAD_REACT_CODE_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can.
+ To do so, you have been given access to a list of tools: these tools are basically Python functions which you can call with code.
+ To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
+
+ Your most important tool is the `squad_retriever` tool, which can answer questions from the Stanford Question Answering Dataset (SQuAD).
+ Not all questions will require the `squad_retriever` tool, but whenever you need to answer a question, you should start with this tool first, and then refine your answer only as needed to align with the question and chat history.
+
+ At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use.
+ Then in the 'Code:' sequence, you should write the code in simple Python. The code sequence must end with the '<end_action>' sequence.
+ During each intermediate step, you can use 'print()' to save whatever important information you will then need.
+ These print outputs will then appear in the 'Observation:' field, which will be available as input for the next step.
+ In the end you have to return a final answer using the `final_answer` tool.
+
+ Here are a few examples using notional tools:
+ ---
+ Task: "Generate an image of the oldest person in this document."
+
+ Thought: I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
+ Code:
+ ```py
+ answer = document_qa(document=document, question="Who is the oldest person mentioned?")
+ print(answer)
+ ```<end_action>
+ Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
+
+ Thought: I will now generate an image showcasing the oldest person.
+ Code:
+ ```py
+ image = image_generator("A portrait of John Doe, a 55-year-old man living in Canada.")
+ final_answer(image)
+ ```<end_action>
+
+ ---
+ Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
+
+ Thought: I will use python code to compute the result of the operation and then return the final answer using the `final_answer` tool.
+ Code:
+ ```py
+ result = 5 + 3 + 1294.678
+ final_answer(result)
+ ```<end_action>
+
+ ---
+ Task: "Which city has the highest population: Guangzhou or Shanghai?"
+
+ Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.
+ Code:
+ ```py
+ population_guangzhou = search("Guangzhou population")
+ print("Population Guangzhou:", population_guangzhou)
+ population_shanghai = search("Shanghai population")
+ print("Population Shanghai:", population_shanghai)
+ ```<end_action>
+ Observation:
+ Population Guangzhou: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
+ Population Shanghai: '26 million (2019)'
+
+ Thought: Now I know that Shanghai has the highest population.
+ Code:
+ ```py
+ final_answer("Shanghai")
+ ```<end_action>
+
+ ---
+ Task: "What is the current age of the pope, raised to the power 0.36?"
+
+ Thought: I will use the tool `wiki` to get the age of the pope, then raise it to the power 0.36.
+ Code:
+ ```py
+ pope_age = wiki(query="current pope age")
+ print("Pope age:", pope_age)
+ ```<end_action>
+ Observation:
+ Pope age: "The pope Francis is currently 85 years old."
+
+ Thought: I know that the pope is 85 years old. Let's compute the result using python code.
+ Code:
+ ```py
+ pope_current_age = 85 ** 0.36
+ final_answer(pope_current_age)
+ ```<end_action>
+
+ The above examples used notional tools that might not exist for you. On top of performing computations in the Python code snippets that you create, you have access to these tools (and no other tool):
+
+ <<tool_descriptions>>
+
+ <<managed_agents_descriptions>>
+
+ Here are the rules you should always follow to solve your task:
+ 1. Always provide a 'Thought:' sequence, and a 'Code:\n```py' sequence ending with '```<end_action>' sequence, else you will fail.
+ 2. Use only variables that you have defined!
+ 3. Always use the right arguments for the tools. DO NOT pass the arguments as a dict as in 'answer = wiki({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = wiki(query="What is the place where James Bond lives?")'.
+ 4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
+ 5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
+ 6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
+ 7. Never create any notional variables in your code, as having these in your logs might derail you from the true variables.
+ 8. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
+ 9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
+ 10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
+ 11. Only use the tools that have been provided to you.
+ 12. Only generate an image when asked to do so.
+ 13. If the task questions the rationale of your previous answers, explain your rationale for the previous answers and attempt to correct any mistakes in your previous answers.
+
+ As for your identity, your name is Agent SQuAD, you are an AI Agent, an expert guide to all questions and answers in the Stanford Question Answering Dataset (SQuAD), and you are SQuADtacular!
+ Do not use the squad_retriever tool to answer questions about yourself, such as "what is your name" or "what are you".
+
+ Now begin! If you solve the task correctly, you will receive a reward of $1,000,000.
+ """
run.py ADDED
@@ -0,0 +1,30 @@
+ import gradio as gr
+ from gradio import ChatMessage
+ from transformers import load_tool, ReactCodeAgent, HfEngine  # type: ignore
+ from utils import stream_from_transformers_agent
+
+ # Import tool from Hub
+ image_generation_tool = load_tool("m-ric/text-to-image")
+
+ llm_engine = HfEngine("meta-llama/Meta-Llama-3-70B-Instruct")
+ # Initialize the agent with the image generation tool
+ agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)
+
+ def interact_with_agent(prompt, messages):
+     messages.append(ChatMessage(role="user", content=prompt))
+     yield messages
+     for msg in stream_from_transformers_agent(agent, prompt):
+         messages.append(msg)
+         yield messages
+     yield messages
+
+ with gr.Blocks() as demo:
+     stored_message = gr.State([])
+     chatbot = gr.Chatbot(
+         label="Agent",
+         type="messages",
+         avatar_images=(None, "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"),
+     )
+     text_input = gr.Textbox(lines=1, label="Chat Message")
+     text_input.submit(lambda s: (s, ""), [text_input], [stored_message, text_input]).then(interact_with_agent, [stored_message, chatbot], [chatbot])
+
+ if __name__ == "__main__":
+     demo.launch()
test_bots.py ADDED
@@ -0,0 +1,14 @@
+ import pytest
+ from deepeval import assert_test
+ from deepeval.metrics import AnswerRelevancyMetric
+ from deepeval.test_case import LLMTestCase
+
+ def test_case():
+     answer_relevancy_metric = AnswerRelevancyMetric(threshold=0.5)
+     test_case = LLMTestCase(
+         input="What if these shoes don't fit?",
+         # Replace this with the actual output from your LLM application
+         actual_output="We offer a 30-day full refund at no extra costs.",
+         retrieval_context=["All customers are eligible for a 30 day full refund at no extra costs."]
+     )
+     assert_test(test_case, [answer_relevancy_metric])
tools/squad_retriever.py ADDED
@@ -0,0 +1,30 @@
+ from transformers.agents.tools import Tool
+ from data import Data
+
+ class SquadRetrieverTool(Tool):
+     name = "squad_retriever"
+     description = "Answers questions from the Stanford Question Answering Dataset (SQuAD)."
+     inputs = {
+         "query": {
+             "type": "string",
+             "description": "The question. This should be the literal question being asked, only modified to be informed by chat history. Be sure to pass this as a keyword argument and not a dictionary.",
+         },
+     }
+     output_type = "string"
+
+     def __init__(self, **kwargs):
+         super().__init__(**kwargs)
+         self.data = Data()
+         self.query_engine = self.data.index.as_query_engine()
+
+     def forward(self, query: str) -> str:
+         assert isinstance(query, str), "Your search query must be a string"
+
+         response = self.query_engine.query(query)
+         # docs = self.data.index.similarity_search(query, k=3)
+
+         if not response.response:
+             return "No answer found for this query."
+         return "Retrieved answer:\n\n" + response.response
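
A quick sketch of exercising the tool on its own, outside of any agent; `Tool` instances from `transformers.agents` are callable, and the SQuAD data and Chroma index from `data.py` are assumed to be available (the question is illustrative):

```python
# Sketch: call the retriever tool directly, the way the agent would.
from tools.squad_retriever import SquadRetrieverTool

retriever = SquadRetrieverTool()
print(retriever(query="What is on top of the Notre Dame building?"))
```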
tools/text_to_image.py ADDED
@@ -0,0 +1,13 @@
+ from transformers.agents.tools import Tool
+ from huggingface_hub import InferenceClient
+
+ class TextToImageTool(Tool):
+     description = "This is a tool that creates an image according to a prompt, which is a text description."
+     name = "image_generator"
+     inputs = {"prompt": {"type": "string", "description": "The image generator prompt. Don't hesitate to add details in the prompt to make the image look better, like 'high-res, photorealistic', etc."}}
+     output_type = "image"
+     model_sdxl = "stabilityai/stable-diffusion-xl-base-1.0"
+     client = InferenceClient(model_sdxl)
+
+     def forward(self, prompt):
+         return self.client.text_to_image(prompt)
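
A usage sketch; `InferenceClient.text_to_image` returns a PIL image, so the result can be saved directly (the prompt and filename are illustrative):

```python
# Sketch: generate and save one image with the SDXL-backed tool.
from tools.text_to_image import TextToImageTool

image_generator = TextToImageTool()
image = image_generator(prompt="A golden statue of the Virgin Mary, high-res, photorealistic")
image.save("statue.png")
```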
tools/visual_qa.py ADDED
@@ -0,0 +1,191 @@
+ from PIL import Image
+ import base64
+ from io import BytesIO
+ import json
+ import os
+ import requests
+ from typing import Optional
+ from huggingface_hub import InferenceClient
+ from transformers import AutoProcessor, Tool
+ import uuid
+ import mimetypes
+ from dotenv import load_dotenv
+
+ load_dotenv(override=True)
+
+ idefics_processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
+
+ def process_images_and_text(image_path, query, client):
+     messages = [
+         {
+             "role": "user", "content": [
+                 {"type": "image"},
+                 {"type": "text", "text": query},
+             ]
+         },
+     ]
+
+     prompt_with_template = idefics_processor.apply_chat_template(messages, add_generation_prompt=True)
+
+     # encode the local image to a string which can be sent to the endpoint
+     def encode_local_image(image_path):
+         # load image
+         image = Image.open(image_path).convert('RGB')
+
+         # Convert the image to a base64 string
+         buffer = BytesIO()
+         image.save(buffer, format="JPEG")  # Use the appropriate format (e.g., JPEG, PNG)
+         base64_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
+
+         # add string formatting required by the endpoint
+         image_string = f"data:image/jpeg;base64,{base64_image}"
+
+         return image_string
+
+     image_string = encode_local_image(image_path)
+     prompt_with_images = prompt_with_template.replace("<image>", "![]({}) ").format(image_string)
+
+     payload = {
+         "inputs": prompt_with_images,
+         "parameters": {
+             "return_full_text": False,
+             "max_new_tokens": 200,
+         }
+     }
+
+     return json.loads(client.post(json=payload).decode())[0]
+
+ # Function to encode the image
+ def encode_image(image_path):
+     if image_path.startswith("http"):
+         user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
+         request_kwargs = {
+             "headers": {"User-Agent": user_agent},
+             "stream": True,
+         }
+
+         # Send a HTTP request to the URL
+         response = requests.get(image_path, **request_kwargs)
+         response.raise_for_status()
+         content_type = response.headers.get("content-type", "")
+
+         extension = mimetypes.guess_extension(content_type)
+         if extension is None:
+             extension = ".download"
+
+         fname = str(uuid.uuid4()) + extension
+         download_path = os.path.abspath(os.path.join("downloads", fname))
+
+         with open(download_path, "wb") as fh:
+             for chunk in response.iter_content(chunk_size=512):
+                 fh.write(chunk)
+
+         image_path = download_path
+
+     with open(image_path, "rb") as image_file:
+         return base64.b64encode(image_file.read()).decode('utf-8')
+
+ headers = {
+     "Content-Type": "application/json",
+     "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}"
+ }
+
+ def resize_image(image_path):
+     img = Image.open(image_path)
+     width, height = img.size
+     img = img.resize((int(width / 2), int(height / 2)))
+     new_image_path = f"resized_{image_path}"
+     img.save(new_image_path)
+     return new_image_path
+
+ class VisualQATool(Tool):
+     name = "visualizer"
+     description = "A tool that can answer questions about attached images."
+     inputs = {
+         "question": {"description": "the question to answer", "type": "text"},
+         "image_path": {
+             "description": "The path to the image on which to answer the question",
+             "type": "text",
+         },
+     }
+     output_type = "text"
+
+     client = InferenceClient("HuggingFaceM4/idefics2-8b-chatty")
+
+     def forward(self, image_path: str, question: Optional[str] = None) -> str:
+         add_note = False
+         if not question:
+             add_note = True
+             question = "Please write a detailed caption for this image."
+         try:
+             output = process_images_and_text(image_path, question, self.client)
+         except Exception as e:
+             print(e)
+             if "Payload Too Large" in str(e):
+                 # The endpoint rejects large payloads: retry with a downscaled image.
+                 new_image_path = resize_image(image_path)
+                 output = process_images_and_text(new_image_path, question, self.client)
+             else:
+                 raise
+
+         if add_note:
+             output = f"You did not provide a particular question, so here is a detailed caption for the image: {output}"
+
+         return output
+
+ class VisualQAGPT4Tool(Tool):
+     name = "visualizer"
+     description = "A tool that can answer questions about attached images."
+     inputs = {
+         "question": {"description": "the question to answer", "type": "text"},
+         "image_path": {
+             "description": "The path to the image on which to answer the question. This should be a local path to a downloaded image.",
+             "type": "text",
+         },
+     }
+     output_type = "text"
+
+     def forward(self, image_path: str, question: Optional[str] = None) -> str:
+         add_note = False
+         if not question:
+             add_note = True
+             question = "Please write a detailed caption for this image."
+         if not isinstance(image_path, str):
+             raise Exception("You should provide only one string as argument to this tool!")
+
+         base64_image = encode_image(image_path)
+
+         payload = {
+             "model": "gpt-4o",
+             "messages": [
+                 {
+                     "role": "user",
+                     "content": [
+                         {
+                             "type": "text",
+                             "text": question
+                         },
+                         {
+                             "type": "image_url",
+                             "image_url": {
+                                 "url": f"data:image/jpeg;base64,{base64_image}"
+                             }
+                         }
+                     ]
+                 }
+             ],
+             "max_tokens": 500
+         }
+         response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
+         try:
+             output = response.json()['choices'][0]['message']['content']
+         except Exception:
+             raise Exception(f"Response format unexpected: {response.json()}")
+
+         if add_note:
+             output = f"You did not provide a particular question, so here is a detailed caption for the image: {output}"
+
+         return output
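
A usage sketch for the two visualizer variants above; `VisualQATool` calls the IDEFICS-2 inference endpoint, while `VisualQAGPT4Tool` needs `OPENAI_API_KEY` in the environment. The image path is illustrative:

```python
# Sketch: ask a question about a local image, or get a caption by omitting it.
from tools.visual_qa import VisualQATool

visualizer = VisualQATool()
print(visualizer(image_path="downloads/example.jpg", question="What is shown here?"))
print(visualizer(image_path="downloads/example.jpg"))  # falls back to a detailed caption
```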
tools/web_surfer.py ADDED
@@ -0,0 +1,205 @@
+ # Shamelessly stolen from the Microsoft Autogen team: thanks to them for this great resource!
+ # https://github.com/microsoft/autogen/blob/gaia_multiagent_v01_march_1st/autogen/browser_utils.py
+ import os
+ import re
+ from typing import Tuple, Optional
+ from transformers.agents.agents import Tool
+ import time
+ from dotenv import load_dotenv
+ import requests
+ from pypdf import PdfReader
+ from markdownify import markdownify as md
+ import mimetypes
+ from .browser import SimpleTextBrowser
+
+ load_dotenv(override=True)
+
+ user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
+
+ browser_config = {
+     "viewport_size": 1024 * 5,
+     "downloads_folder": "coding",
+     "request_kwargs": {
+         "headers": {"User-Agent": user_agent},
+         "timeout": 300,
+     },
+ }
+
+ browser_config["serpapi_key"] = os.environ["SERPAPI_API_KEY"]
+
+ browser = SimpleTextBrowser(**browser_config)
+
+
+ # Helper function
+ def _browser_state() -> Tuple[str, str]:
+     header = f"Address: {browser.address}\n"
+     if browser.page_title is not None:
+         header += f"Title: {browser.page_title}\n"
+
+     current_page = browser.viewport_current_page
+     total_pages = len(browser.viewport_pages)
+
+     address = browser.address
+     for i in range(len(browser.history) - 2, -1, -1):  # Start from the second last
+         if browser.history[i][0] == address:
+             header += f"You previously visited this page {round(time.time() - browser.history[i][1])} seconds ago.\n"
+             break
+
+     header += f"Viewport position: Showing page {current_page + 1} of {total_pages}.\n"
+     return (header, browser.viewport)
+
+
+ class SearchInformationTool(Tool):
+     name = "informational_web_search"
+     description = "Perform an INFORMATIONAL web search query then return the search results."
+     inputs = {
+         "query": {
+             "type": "text",
+             "description": "The informational web search query to perform."
+         },
+         "filter_year": {
+             "type": "text",
+             "description": "[Optional parameter]: filter the search results to only include pages from a specific year. For example, '2020' will only include pages from 2020. Make sure to use this parameter if you're trying to search for articles from a specific date!"
+         },
+     }
+     output_type = "text"
+
+     def forward(self, query: str, filter_year: Optional[int] = None) -> str:
+         browser.visit_page(f"google: {query}", filter_year=filter_year)
+         header, content = _browser_state()
+         return header.strip() + "\n=======================\n" + content
+
+
+ class NavigationalSearchTool(Tool):
+     name = "navigational_web_search"
+     description = "Perform a NAVIGATIONAL web search query then immediately navigate to the top result. Useful, for example, to navigate to a particular Wikipedia article or other known destination. Equivalent to Google's \"I'm Feeling Lucky\" button."
+     inputs = {"query": {"type": "text", "description": "The navigational web search query to perform."}}
+     output_type = "text"
+
+     def forward(self, query: str) -> str:
+         browser.visit_page(f"google: {query}")
+
+         # Extract the first linked result and navigate to it
+         m = re.search(r"\[.*?\]\((http.*?)\)", browser.page_content)
+         if m:
+             browser.visit_page(m.group(1))
+
+         # Return where we ended up
+         header, content = _browser_state()
+         return header.strip() + "\n=======================\n" + content
+
+
+ class VisitTool(Tool):
+     name = "visit_page"
+     description = "Visit a webpage at a given URL and return its text."
+     inputs = {"url": {"type": "text", "description": "The relative or absolute url of the webpage to visit."}}
+     output_type = "text"
+
+     def forward(self, url: str) -> str:
+         browser.visit_page(url)
+         header, content = _browser_state()
+         return header.strip() + "\n=======================\n" + content
+
+
+ class DownloadTool(Tool):
+     name = "download_file"
+     description = """
+ Download a file at a given URL. The file should be of this format: [".xlsx", ".pptx", ".wav", ".mp3", ".png", ".docx"]
+ After using this tool, for further inspection of this page you should return the download path to your manager via final_answer, and they will be able to inspect it.
+ DO NOT use this tool for .pdf or .txt or .htm files: for these types of files use visit_page with the file url instead."""
+     inputs = {"url": {"type": "text", "description": "The relative or absolute url of the file to be downloaded."}}
+     output_type = "text"
+
+     def forward(self, url: str) -> str:
+         if "arxiv" in url:
+             url = url.replace("abs", "pdf")
+         response = requests.get(url)
+         content_type = response.headers.get("content-type", "")
+         extension = mimetypes.guess_extension(content_type)
+         if extension and isinstance(extension, str):
+             new_path = f"./downloads/file{extension}"
+         else:
+             new_path = "./downloads/file.object"
+
+         with open(new_path, "wb") as f:
+             f.write(response.content)
+
+         # Guard against a None extension before the substring checks
+         if extension and ("pdf" in extension or "txt" in extension or "htm" in extension):
+             raise Exception("Do not use this tool for pdf or txt or html files: use visit_page instead.")
+
+         return f"File was downloaded and saved under path {new_path}."
+
+
+ class PageUpTool(Tool):
+     name = "page_up"
+     description = "Scroll the viewport UP one page-length in the current webpage and return the new viewport content."
+     output_type = "text"
+
+     def forward(self) -> str:
+         browser.page_up()
+         header, content = _browser_state()
+         return header.strip() + "\n=======================\n" + content
+
+
+ class ArchiveSearchTool(Tool):
+     name = "find_archived_url"
+     description = "Given a url, searches the Wayback Machine and returns the archived version of the url that's closest in time to the desired date."
+     inputs = {
+         "url": {"type": "text", "description": "The url you need the archive for."},
+         "date": {"type": "text", "description": "The date that you want to find the archive for. Give this date in the format 'YYYYMMDD', for instance '27 June 2008' is written as '20080627'."}
+     }
+     output_type = "text"
+
+     def forward(self, url, date) -> str:
+         archive_url = f"https://archive.org/wayback/available?url={url}&timestamp={date}"
+         response = requests.get(archive_url).json()
+         try:
+             closest = response["archived_snapshots"]["closest"]
+         except KeyError:
+             raise Exception("Your url was not archived on Wayback Machine, try a different url.")
+         target_url = closest["url"]
+         browser.visit_page(target_url)
+         header, content = _browser_state()
+         return f"Web archive for url {url}, snapshot taken at date {closest['timestamp'][:8]}:\n" + header.strip() + "\n=======================\n" + content
+
+
+ class PageDownTool(Tool):
+     name = "page_down"
+     description = "Scroll the viewport DOWN one page-length in the current webpage and return the new viewport content."
+     output_type = "text"
+
+     def forward(self) -> str:
+         browser.page_down()
+         header, content = _browser_state()
+         return header.strip() + "\n=======================\n" + content
+
+
+ class FinderTool(Tool):
+     name = "find_on_page_ctrl_f"
+     description = "Scroll the viewport to the first occurrence of the search string. This is equivalent to Ctrl+F."
+     inputs = {"search_string": {"type": "text", "description": "The string to search for on the page. This search string supports wildcards like '*'"}}
+     output_type = "text"
+
+     def forward(self, search_string: str) -> str:
+         find_result = browser.find_on_page(search_string)
+         header, content = _browser_state()
+
+         if find_result is None:
+             return header.strip() + f"\n=======================\nThe search string '{search_string}' was not found on this page."
+         else:
+             return header.strip() + "\n=======================\n" + content
+
+
+ class FindNextTool(Tool):
+     name = "find_next"
+     description = "Scroll the viewport to next occurrence of the search string. This is equivalent to finding the next match in a Ctrl+F search."
+     inputs = {}
+     output_type = "text"
+
+     def forward(self) -> str:
+         find_result = browser.find_next()
+         header, content = _browser_state()
+
+         if find_result is None:
+             return header.strip() + "\n=======================\nThe search string was not found on this page."
+         else:
+             return header.strip() + "\n=======================\n" + content
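
A usage sketch for the browsing tools above; all of them share the module-level `SimpleTextBrowser`, so state (current page, scroll position) carries over between calls. This assumes the `tools/browser.py` module referenced by the import and a `SERPAPI_API_KEY` in the environment; the queries are illustrative:

```python
# Sketch: search, jump to the top result, and page through the shared browser.
from tools.web_surfer import (
    SearchInformationTool,
    NavigationalSearchTool,
    PageDownTool,
)

print(SearchInformationTool()(query="Stanford SQuAD dataset"))
print(NavigationalSearchTool()(query="SQuAD explorer"))  # "I'm Feeling Lucky"-style jump
print(PageDownTool()())  # scroll the shared browser one viewport down
```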
utils.py ADDED
@@ -0,0 +1,67 @@
+ from __future__ import annotations
+
+ from gradio import ChatMessage
+ from transformers.agents import ReactCodeAgent, agent_types
+ from typing import Generator
+
+ def pull_message(step_log: dict):
+     if step_log.get("rationale"):
+         yield ChatMessage(
+             role="assistant",
+             metadata={"title": "🧠 Rationale"},
+             content=step_log["rationale"],
+         )
+     if step_log.get("tool_call"):
+         used_code = step_log["tool_call"]["tool_name"] == "code interpreter"
+         content = step_log["tool_call"]["tool_arguments"]
+         if used_code:
+             content = f"```py\n{content}\n```"
+         yield ChatMessage(
+             role="assistant",
+             metadata={"title": f"🛠️ Used tool {step_log['tool_call']['tool_name']}"},
+             content=content,
+         )
+     if step_log.get("observation"):
+         yield ChatMessage(
+             role="assistant",
+             metadata={"title": "👀 Observation"},
+             content=f"```\n{step_log['observation']}\n```",
+         )
+     if step_log.get("error"):
+         yield ChatMessage(
+             role="assistant",
+             metadata={"title": "💥 Error"},
+             content=str(step_log["error"]),
+         )
+
+ def stream_from_transformers_agent(
+     agent: ReactCodeAgent, prompt: str,
+ ) -> Generator[ChatMessage, None, ChatMessage | None]:
+     """Runs an agent with the given prompt and streams the messages from the agent as ChatMessages."""
+
+     class Output:
+         output: agent_types.AgentType | str = None
+
+     step_log = None
+     for step_log in agent.run(prompt, stream=True, reset=len(agent.logs) == 0):  # reset=False misbehaves if the agent has not yet been run
+         if isinstance(step_log, dict):
+             for message in pull_message(step_log):
+                 print("message", message)
+                 yield message
+
+     Output.output = step_log
+     if isinstance(Output.output, agent_types.AgentText):
+         yield ChatMessage(
+             role="assistant", content=f"**Final answer:**\n```\n{Output.output.to_string()}\n```")  # type: ignore
+     elif isinstance(Output.output, agent_types.AgentImage):
+         yield ChatMessage(
+             role="assistant",
+             content={"path": Output.output.to_string(), "mime_type": "image/png"},  # type: ignore
+         )
+     elif isinstance(Output.output, agent_types.AgentAudio):
+         yield ChatMessage(
+             role="assistant",
+             content={"path": Output.output.to_string(), "mime_type": "audio/wav"},  # type: ignore
+         )
+     else:
+         return ChatMessage(role="assistant", content=Output.output)
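
For reference, a sketch of how this generator is meant to be consumed (as `interact_with_agent` in `app.py` does); the empty toolbox and the arithmetic prompt are illustrative:

```python
# Sketch: collect streamed agent steps; each step arrives as a ChatMessage.
from gradio import ChatMessage
from transformers import ReactCodeAgent, HfApiEngine

from utils import stream_from_transformers_agent

agent = ReactCodeAgent(tools=[], llm_engine=HfApiEngine("meta-llama/Meta-Llama-3.1-8B-Instruct"))

history: list[ChatMessage] = []
for msg in stream_from_transformers_agent(agent, "What is 5 + 3 + 1294.678?"):
    history.append(msg)  # rationale, tool calls, observations, then the final answer
    print(msg.content)
```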