Upload folder using huggingface_hub
Browse files- .github/workflows/update_space.yml +28 -0
- .gitignore +169 -0
- README.md +59 -10
- app.py +88 -0
- bots.py +70 -0
- data.py +80 -0
- prompts.py +108 -0
- run.py +30 -0
- test_bots.py +14 -0
- tools/squad_retriever.py +30 -0
- tools/text_to_image.py +13 -0
- tools/visual_qa.py +191 -0
- tools/web_surfer.py +205 -0
- utils.py +67 -0
.github/workflows/update_space.yml
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Run Python script
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches:
|
6 |
+
- main
|
7 |
+
|
8 |
+
jobs:
|
9 |
+
build:
|
10 |
+
runs-on: ubuntu-latest
|
11 |
+
|
12 |
+
steps:
|
13 |
+
- name: Checkout
|
14 |
+
uses: actions/checkout@v2
|
15 |
+
|
16 |
+
- name: Set up Python
|
17 |
+
uses: actions/setup-python@v2
|
18 |
+
with:
|
19 |
+
python-version: '3.9'
|
20 |
+
|
21 |
+
- name: Install Gradio
|
22 |
+
run: python -m pip install gradio
|
23 |
+
|
24 |
+
- name: Log in to Hugging Face
|
25 |
+
run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
|
26 |
+
|
27 |
+
- name: Deploy to Spaces
|
28 |
+
run: gradio deploy
|
.gitignore
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MacOS
|
2 |
+
.DS_Store
|
3 |
+
|
4 |
+
# Data
|
5 |
+
chroma_db/
|
6 |
+
data/
|
7 |
+
|
8 |
+
# Byte-compiled / optimized / DLL files
|
9 |
+
__pycache__/
|
10 |
+
*.py[cod]
|
11 |
+
*$py.class
|
12 |
+
|
13 |
+
# C extensions
|
14 |
+
*.so
|
15 |
+
|
16 |
+
# Distribution / packaging
|
17 |
+
.Python
|
18 |
+
build/
|
19 |
+
develop-eggs/
|
20 |
+
dist/
|
21 |
+
downloads/
|
22 |
+
eggs/
|
23 |
+
.eggs/
|
24 |
+
lib/
|
25 |
+
lib64/
|
26 |
+
parts/
|
27 |
+
sdist/
|
28 |
+
var/
|
29 |
+
wheels/
|
30 |
+
share/python-wheels/
|
31 |
+
*.egg-info/
|
32 |
+
.installed.cfg
|
33 |
+
*.egg
|
34 |
+
MANIFEST
|
35 |
+
|
36 |
+
# PyInstaller
|
37 |
+
# Usually these files are written by a python script from a template
|
38 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
39 |
+
*.manifest
|
40 |
+
*.spec
|
41 |
+
|
42 |
+
# Installer logs
|
43 |
+
pip-log.txt
|
44 |
+
pip-delete-this-directory.txt
|
45 |
+
|
46 |
+
# Unit test / coverage reports
|
47 |
+
htmlcov/
|
48 |
+
.tox/
|
49 |
+
.nox/
|
50 |
+
.coverage
|
51 |
+
.coverage.*
|
52 |
+
.cache
|
53 |
+
nosetests.xml
|
54 |
+
coverage.xml
|
55 |
+
*.cover
|
56 |
+
*.py,cover
|
57 |
+
.hypothesis/
|
58 |
+
.pytest_cache/
|
59 |
+
cover/
|
60 |
+
|
61 |
+
# Translations
|
62 |
+
*.mo
|
63 |
+
*.pot
|
64 |
+
|
65 |
+
# Django stuff:
|
66 |
+
*.log
|
67 |
+
local_settings.py
|
68 |
+
db.sqlite3
|
69 |
+
db.sqlite3-journal
|
70 |
+
|
71 |
+
# Flask stuff:
|
72 |
+
instance/
|
73 |
+
.webassets-cache
|
74 |
+
|
75 |
+
# Scrapy stuff:
|
76 |
+
.scrapy
|
77 |
+
|
78 |
+
# Sphinx documentation
|
79 |
+
docs/_build/
|
80 |
+
|
81 |
+
# PyBuilder
|
82 |
+
.pybuilder/
|
83 |
+
target/
|
84 |
+
|
85 |
+
# Jupyter Notebook
|
86 |
+
.ipynb_checkpoints
|
87 |
+
|
88 |
+
# IPython
|
89 |
+
profile_default/
|
90 |
+
ipython_config.py
|
91 |
+
|
92 |
+
# pyenv
|
93 |
+
# For a library or package, you might want to ignore these files since the code is
|
94 |
+
# intended to run in multiple environments; otherwise, check them in:
|
95 |
+
# .python-version
|
96 |
+
|
97 |
+
# pipenv
|
98 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
99 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
100 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
101 |
+
# install all needed dependencies.
|
102 |
+
#Pipfile.lock
|
103 |
+
|
104 |
+
# poetry
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
106 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
107 |
+
# commonly ignored for libraries.
|
108 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
109 |
+
#poetry.lock
|
110 |
+
|
111 |
+
# pdm
|
112 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
113 |
+
#pdm.lock
|
114 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
115 |
+
# in version control.
|
116 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
117 |
+
.pdm.toml
|
118 |
+
.pdm-python
|
119 |
+
.pdm-build/
|
120 |
+
|
121 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
122 |
+
__pypackages__/
|
123 |
+
|
124 |
+
# Celery stuff
|
125 |
+
celerybeat-schedule
|
126 |
+
celerybeat.pid
|
127 |
+
|
128 |
+
# SageMath parsed files
|
129 |
+
*.sage.py
|
130 |
+
|
131 |
+
# Environments
|
132 |
+
.env
|
133 |
+
.venv
|
134 |
+
env/
|
135 |
+
venv/
|
136 |
+
ENV/
|
137 |
+
env.bak/
|
138 |
+
venv.bak/
|
139 |
+
|
140 |
+
# Spyder project settings
|
141 |
+
.spyderproject
|
142 |
+
.spyproject
|
143 |
+
|
144 |
+
# Rope project settings
|
145 |
+
.ropeproject
|
146 |
+
|
147 |
+
# mkdocs documentation
|
148 |
+
/site
|
149 |
+
|
150 |
+
# mypy
|
151 |
+
.mypy_cache/
|
152 |
+
.dmypy.json
|
153 |
+
dmypy.json
|
154 |
+
|
155 |
+
# Pyre type checker
|
156 |
+
.pyre/
|
157 |
+
|
158 |
+
# pytype static type analyzer
|
159 |
+
.pytype/
|
160 |
+
|
161 |
+
# Cython debug symbols
|
162 |
+
cython_debug/
|
163 |
+
|
164 |
+
# PyCharm
|
165 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
166 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
167 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
168 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
169 |
+
#.idea/
|
README.md
CHANGED
@@ -1,14 +1,63 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji: 👁
|
4 |
-
colorFrom: gray
|
5 |
-
colorTo: pink
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 5.0.1
|
8 |
app_file: app.py
|
9 |
-
|
10 |
-
|
11 |
-
short_description: SQuAD Question Answering Agent
|
12 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
|
|
|
|
|
|
1 |
---
|
2 |
+
title: SQuAD_Agent_Experiment
|
|
|
|
|
|
|
|
|
|
|
3 |
app_file: app.py
|
4 |
+
sdk: gradio
|
5 |
+
sdk_version: 4.44.0
|
|
|
6 |
---
|
7 |
+
# SQuAD_Agent_Experiment
|
8 |
+
|
9 |
+
## Overview
|
10 |
+
|
11 |
+
The project is built using Transformers Agents 2.0, and uses the Stanford SQuAD dataset for training. The chatbot is designed to answer questions about the dataset, while also incorporating conversational context and various tools to provide a more natural and engaging conversational experience.
|
12 |
+
|
13 |
+
## Getting Started
|
14 |
+
|
15 |
+
1. Install dependencies:
|
16 |
+
|
17 |
+
```bash
|
18 |
+
pip install -r requirements.txt
|
19 |
+
```
|
20 |
+
|
21 |
+
1. Set up required keys:
|
22 |
+
|
23 |
+
```bash
|
24 |
+
HUGGINGFACE_API_TOKEN=<your token>
|
25 |
+
```
|
26 |
+
|
27 |
+
1. Run the app:
|
28 |
+
|
29 |
+
```bash
|
30 |
+
python app.py
|
31 |
+
```
|
32 |
+
|
33 |
+
## Methods Used
|
34 |
+
|
35 |
+
1. SQuAD Dataset: The dataset used for training the chatbot is the Stanford SQuAD dataset, which contains over 100,000 questions and answers extracted from 500+ articles.
|
36 |
+
2. RAG: RAG is a technique used to improve the accuracy of chatbots by using a custom knowledge base. In this project, the Stanford SQuAD dataset is used as the knowledge base.
|
37 |
+
3. Llama 3.1: Llama 3.1 is a large language model used to generate responses to user questions. It is used in this project to generate responses to user questions, while also incorporating conversational context.
|
38 |
+
4. Transformers Agents 2.0: Transformers Agents 2.0 is a framework for building conversational AI systems. It is used in this project to build the chatbot.
|
39 |
+
5. Created a SquadRetrieverTool to integrate a fine-tuned BERT model into the agent, along with a TextToImageTool for a playful way to engage with the question-answering agent.
|
40 |
+
|
41 |
+
## Evaluation
|
42 |
+
|
43 |
+
* [Agent Reasoning Benchmark](https://github.com/aymeric-roucher/agent_reasoning_benchmark)
|
44 |
+
* [Hugging Face Blog: Open Source LLMs as Agents](https://huggingface.co/blog/open-source-llms-as-agents)
|
45 |
+
* [Benchmarking Transformers Agents](https://github.com/aymeric-roucher/agent_reasoning_benchmark/blob/main/benchmark_transformers_agents.ipynb)
|
46 |
+
|
47 |
+
## Results
|
48 |
+
|
49 |
+
TBD
|
50 |
+
|
51 |
+
## Limitations
|
52 |
+
|
53 |
+
TBD
|
54 |
+
|
55 |
+
## Future Work
|
56 |
+
|
57 |
+
TBD
|
58 |
+
|
59 |
+
## Acknowledgments
|
60 |
|
61 |
+
* [MemGPT](https://github.com/cpacker/MemGPT)
|
62 |
+
* [Stanford SQuAD](https://rajpurkar.github.io/SQuAD-explorer/)
|
63 |
+
* [GPT-4](https://openai.com/gpt-4/)
|
app.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from gradio import ChatMessage
|
3 |
+
from transformers import ReactCodeAgent, HfApiEngine
|
4 |
+
from utils import stream_from_transformers_agent
|
5 |
+
from prompts import SQUAD_REACT_CODE_SYSTEM_PROMPT
|
6 |
+
from tools.squad_retriever import SquadRetrieverTool
|
7 |
+
from tools.text_to_image import TextToImageTool
|
8 |
+
from dotenv import load_dotenv
|
9 |
+
|
10 |
+
load_dotenv()
|
11 |
+
|
12 |
+
TASK_SOLVING_TOOLBOX = [
|
13 |
+
SquadRetrieverTool(),
|
14 |
+
TextToImageTool(),
|
15 |
+
]
|
16 |
+
|
17 |
+
model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
18 |
+
# model_name = "http://localhost:1234/v1"
|
19 |
+
|
20 |
+
llm_engine = HfApiEngine(model_name)
|
21 |
+
|
22 |
+
# Initialize the agent with both tools
|
23 |
+
agent = ReactCodeAgent(
|
24 |
+
tools=TASK_SOLVING_TOOLBOX,
|
25 |
+
llm_engine=llm_engine,
|
26 |
+
system_prompt=SQUAD_REACT_CODE_SYSTEM_PROMPT,
|
27 |
+
)
|
28 |
+
|
29 |
+
def append_example_message(x: gr.SelectData, messages):
|
30 |
+
if x.value["text"] is not None:
|
31 |
+
message = x.value["text"]
|
32 |
+
if "files" in x.value:
|
33 |
+
if isinstance(x.value["files"], list):
|
34 |
+
message = "Here are the files: "
|
35 |
+
for file in x.value["files"]:
|
36 |
+
message += f"{file}, "
|
37 |
+
else:
|
38 |
+
message = x.value["files"]
|
39 |
+
messages.append(ChatMessage(role="user", content=message))
|
40 |
+
return messages
|
41 |
+
|
42 |
+
def add_message(message, messages):
|
43 |
+
messages.append(ChatMessage(role="user", content=message))
|
44 |
+
return messages
|
45 |
+
|
46 |
+
def interact_with_agent(messages):
|
47 |
+
prompt = messages[-1]['content']
|
48 |
+
for msg in stream_from_transformers_agent(agent, prompt):
|
49 |
+
messages.append(msg)
|
50 |
+
yield messages
|
51 |
+
yield messages
|
52 |
+
|
53 |
+
with gr.Blocks(fill_height=True) as demo:
|
54 |
+
chatbot = gr.Chatbot(
|
55 |
+
label="SQuAD Agent",
|
56 |
+
type="messages",
|
57 |
+
avatar_images=(
|
58 |
+
None,
|
59 |
+
"https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
|
60 |
+
),
|
61 |
+
scale=1,
|
62 |
+
bubble_full_width=False,
|
63 |
+
autoscroll=True,
|
64 |
+
show_copy_all_button=True,
|
65 |
+
show_copy_button=True,
|
66 |
+
placeholder="Enter a message",
|
67 |
+
examples=[
|
68 |
+
{
|
69 |
+
"text": "What is on top of the Notre Dame building?",
|
70 |
+
},
|
71 |
+
{
|
72 |
+
"text": "Tell me what's on top of the Notre Dame building, and draw a picture of it.",
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"text": "Draw a picture of whatever is on top of the Notre Dame building.",
|
76 |
+
},
|
77 |
+
],
|
78 |
+
)
|
79 |
+
text_input = gr.Textbox(lines=1, label="Chat Message", scale=0)
|
80 |
+
chat_msg = text_input.submit(add_message, [text_input, chatbot], [chatbot])
|
81 |
+
bot_msg = chat_msg.then(interact_with_agent, [chatbot], [chatbot])
|
82 |
+
text_input.submit(lambda: "", None, text_input)
|
83 |
+
chatbot.example_select(append_example_message, [chatbot], [chatbot]).then(
|
84 |
+
interact_with_agent, [chatbot], [chatbot]
|
85 |
+
)
|
86 |
+
|
87 |
+
if __name__ == "__main__":
|
88 |
+
demo.launch()
|
bots.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from data import Data
|
2 |
+
|
3 |
+
'''
|
4 |
+
The BotWrapper class makes it so that different types of bots can be used in the same way.
|
5 |
+
This is used in the Bots class to create a list of all bots and pass them to the frontend.
|
6 |
+
'''
|
7 |
+
class BotWrapper:
|
8 |
+
def __init__(self, bot):
|
9 |
+
self.bot = bot
|
10 |
+
|
11 |
+
def chat(self, *args, **kwargs):
|
12 |
+
methods = ['chat', 'query']
|
13 |
+
for method in methods:
|
14 |
+
if hasattr(self.bot, method):
|
15 |
+
print(f"Calling {method} method")
|
16 |
+
method_to_call = getattr(self.bot, method)
|
17 |
+
return method_to_call(*args, **kwargs).response()
|
18 |
+
raise AttributeError(f"'{self.bot.__class__.__name__}' object has none of the required methods: '{methods}'")
|
19 |
+
|
20 |
+
def stream_chat(self, *args, **kwargs):
|
21 |
+
methods = ['stream_chat', 'query']
|
22 |
+
for method in methods:
|
23 |
+
if hasattr(self.bot, method):
|
24 |
+
print(f"Calling {method} method")
|
25 |
+
method_to_call = getattr(self.bot, method)
|
26 |
+
return method_to_call(*args, **kwargs).response_gen
|
27 |
+
raise AttributeError(f"'{self.bot.__class__.__name__}' object has none of the required methods: '{methods}'")
|
28 |
+
|
29 |
+
'''
|
30 |
+
The Bots class creates the bots and passes them to the frontend.
|
31 |
+
'''
|
32 |
+
class Bots:
|
33 |
+
def __init__(self):
|
34 |
+
self.data = Data()
|
35 |
+
self.data.load_data()
|
36 |
+
self.query_engine = None
|
37 |
+
self.chat_agent = None
|
38 |
+
self.all_bots = None
|
39 |
+
self.create_bots()
|
40 |
+
|
41 |
+
def create_query_engine_bot(self):
|
42 |
+
if self.query_engine is None:
|
43 |
+
self.query_engine = BotWrapper(self.data.index.as_query_engine())
|
44 |
+
return self.query_engine
|
45 |
+
|
46 |
+
def create_chat_agent(self):
|
47 |
+
if self.chat_agent is None:
|
48 |
+
from llama_index.core.memory import ChatMemoryBuffer
|
49 |
+
memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
|
50 |
+
self.chat_agent = BotWrapper(self.data.index.as_chat_engine(
|
51 |
+
chat_mode="context",
|
52 |
+
memory=memory,
|
53 |
+
context_prompt=(
|
54 |
+
"You are a chatbot, able to have normal interactions, as well as talk"
|
55 |
+
" about the questions and answers you know about."
|
56 |
+
"Here are the relevant documents for the context:\n"
|
57 |
+
"{context_str}"
|
58 |
+
"\nInstruction: Use the previous chat history, or the context above, to interact and help the user."
|
59 |
+
)
|
60 |
+
))
|
61 |
+
return self.chat_agent
|
62 |
+
|
63 |
+
def create_bots(self):
|
64 |
+
self.create_query_engine_bot()
|
65 |
+
self.create_chat_agent()
|
66 |
+
self.all_bots = [self.query_engine, self.chat_agent]
|
67 |
+
return self.all_bots
|
68 |
+
|
69 |
+
def get_bots(self):
|
70 |
+
return self.all_bots
|
data.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import chromadb
|
4 |
+
from llama_index.core import VectorStoreIndex
|
5 |
+
from llama_index.vector_stores.chroma import ChromaVectorStore
|
6 |
+
from llama_index.core import StorageContext
|
7 |
+
from llama_index.core import Document
|
8 |
+
|
9 |
+
from dotenv import load_dotenv
|
10 |
+
|
11 |
+
load_dotenv() # Load OPENAI_API_KEY from .env (not included in repo)
|
12 |
+
|
13 |
+
class Data:
|
14 |
+
def __init__(self):
|
15 |
+
self.client = None
|
16 |
+
self.collection = None
|
17 |
+
self.index = None
|
18 |
+
self.load_data()
|
19 |
+
|
20 |
+
def load_data(self):
|
21 |
+
print("Loading data...")
|
22 |
+
with open('data/train-v1.1.json', 'r') as f:
|
23 |
+
raw_data = json.load(f)
|
24 |
+
|
25 |
+
extracted_question = []
|
26 |
+
extracted_answer = []
|
27 |
+
|
28 |
+
for data in raw_data['data']:
|
29 |
+
for par in data['paragraphs']:
|
30 |
+
for qa in par['qas']:
|
31 |
+
for ans in qa['answers']:
|
32 |
+
extracted_question.append(qa['question'])
|
33 |
+
extracted_answer.append(ans['text'])
|
34 |
+
|
35 |
+
documents = []
|
36 |
+
for i in range(len(extracted_question)):
|
37 |
+
documents.append(f"Question: {extracted_question[i]} \nAnswer: {extracted_answer[i]}")
|
38 |
+
|
39 |
+
self.documents = [Document(text=t) for t in documents]
|
40 |
+
self.extracted_question = extracted_question
|
41 |
+
self.extracted_answer = extracted_answer
|
42 |
+
|
43 |
+
print("Raw Data loaded")
|
44 |
+
|
45 |
+
if not os.path.exists("./chroma_db"):
|
46 |
+
print("Creating Chroma DB...")
|
47 |
+
# initialize client, setting path to save data
|
48 |
+
self.client = chromadb.PersistentClient(path="./chroma_db")
|
49 |
+
|
50 |
+
# create collection
|
51 |
+
self.collection = self.client.get_or_create_collection("simple_index")
|
52 |
+
|
53 |
+
# assign chroma as the vector_store to the context
|
54 |
+
vector_store = ChromaVectorStore(chroma_collection=self.collection)
|
55 |
+
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
56 |
+
|
57 |
+
# create your index
|
58 |
+
self.index = VectorStoreIndex.from_documents(
|
59 |
+
self.documents, storage_context=storage_context
|
60 |
+
)
|
61 |
+
print("Chroma DB created")
|
62 |
+
else:
|
63 |
+
print("Chroma DB already exists")
|
64 |
+
|
65 |
+
print("Loading index...")
|
66 |
+
# initialize client
|
67 |
+
self.client = chromadb.PersistentClient(path="./chroma_db")
|
68 |
+
|
69 |
+
# get collection
|
70 |
+
self.collection = self.client.get_or_create_collection("simple_index")
|
71 |
+
|
72 |
+
# assign chroma as the vector_store to the context
|
73 |
+
vector_store = ChromaVectorStore(chroma_collection=self.collection)
|
74 |
+
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
75 |
+
|
76 |
+
# load your index from stored vectors
|
77 |
+
self.index = VectorStoreIndex.from_vector_store(
|
78 |
+
vector_store, storage_context=storage_context
|
79 |
+
)
|
80 |
+
print("Index loaded")
|
prompts.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
SQUAD_REACT_CODE_SYSTEM_PROMPT = """You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can.
|
2 |
+
To do so, you have been given access to a list of tools: these tools are basically Python functions which you can call with code.
|
3 |
+
To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences.
|
4 |
+
|
5 |
+
Your most important tool is the `squad_retriever` tool,which can answer questions from the Stanford Question Answering Dataset (SQuAD).
|
6 |
+
Not all questions will require the `squad_retriever` tool, but whenever you need to answer a question, you should start with this tool first, and then refine your answer only as needed to align with the question and chat history.
|
7 |
+
|
8 |
+
At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use.
|
9 |
+
Then in the 'Code:' sequence, you should write the code in simple Python. The code sequence must end with '<end_action>' sequence.
|
10 |
+
During each intermediate step, you can use 'print()' to save whatever important information you will then need.
|
11 |
+
These print outputs will then appear in the 'Observation:' field, which will be available as input for the next step.
|
12 |
+
In the end you have to return a final answer using the `final_answer` tool.
|
13 |
+
|
14 |
+
Here are a few examples using notional tools:
|
15 |
+
---
|
16 |
+
Task: "Generate an image of the oldest person in this document."
|
17 |
+
|
18 |
+
Thought: I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer.
|
19 |
+
Code:
|
20 |
+
```py
|
21 |
+
answer = document_qa(document=document, question="Who is the oldest person mentioned?")
|
22 |
+
print(answer)
|
23 |
+
```<end_action>
|
24 |
+
Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland."
|
25 |
+
|
26 |
+
Thought: I will now generate an image showcasing the oldest person.
|
27 |
+
Code:
|
28 |
+
```py
|
29 |
+
image = image_generator("A portrait of John Doe, a 55-year-old man living in Canada.")
|
30 |
+
final_answer(image)
|
31 |
+
```<end_action>
|
32 |
+
|
33 |
+
---
|
34 |
+
Task: "What is the result of the following operation: 5 + 3 + 1294.678?"
|
35 |
+
|
36 |
+
Thought: I will use python code to compute the result of the operation and then return the final answer using the `final_answer` tool
|
37 |
+
Code:
|
38 |
+
```py
|
39 |
+
result = 5 + 3 + 1294.678
|
40 |
+
final_answer(result)
|
41 |
+
```<end_action>
|
42 |
+
|
43 |
+
---
|
44 |
+
Task: "Which city has the highest population: Guangzhou or Shanghai?"
|
45 |
+
|
46 |
+
Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities.
|
47 |
+
Code:
|
48 |
+
```py
|
49 |
+
population_guangzhou = search("Guangzhou population")
|
50 |
+
print("Population Guangzhou:", population_guangzhou)
|
51 |
+
population_shanghai = search("Shanghai population")
|
52 |
+
print("Population Shanghai:", population_shanghai)
|
53 |
+
```<end_action>
|
54 |
+
Observation:
|
55 |
+
Population Guangzhou: ['Guangzhou has a population of 15 million inhabitants as of 2021.']
|
56 |
+
Population Shanghai: '26 million (2019)'
|
57 |
+
|
58 |
+
Thought: Now I know that Shanghai has the highest population.
|
59 |
+
Code:
|
60 |
+
```py
|
61 |
+
final_answer("Shanghai")
|
62 |
+
```<end_action>
|
63 |
+
|
64 |
+
---
|
65 |
+
Task: "What is the current age of the pope, raised to the power 0.36?"
|
66 |
+
|
67 |
+
Thought: I will use the tool `wiki` to get the age of the pope, then raise it to the power 0.36.
|
68 |
+
Code:
|
69 |
+
```py
|
70 |
+
pope_age = wiki(query="current pope age")
|
71 |
+
print("Pope age:", pope_age)
|
72 |
+
```<end_action>
|
73 |
+
Observation:
|
74 |
+
Pope age: "The pope Francis is currently 85 years old."
|
75 |
+
|
76 |
+
Thought: I know that the pope is 85 years old. Let's compute the result using python code.
|
77 |
+
Code:
|
78 |
+
```py
|
79 |
+
pope_current_age = 85 ** 0.36
|
80 |
+
final_answer(pope_current_age)
|
81 |
+
```<end_action>
|
82 |
+
|
83 |
+
Above example were using notional tools that might not exist for you. On top of performing computations in the Python code snippets that you create, you have access to those tools (and no other tool):
|
84 |
+
|
85 |
+
<<tool_descriptions>>
|
86 |
+
|
87 |
+
<<managed_agents_descriptions>>
|
88 |
+
|
89 |
+
Here are the rules you should always follow to solve your task:
|
90 |
+
1. Always provide a 'Thought:' sequence, and a 'Code:\n```py' sequence ending with '```<end_action>' sequence, else you will fail.
|
91 |
+
2. Use only variables that you have defined!
|
92 |
+
3. Always use the right arguments for the tools. DO NOT pass the arguments as a dict as in 'answer = wiki({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = wiki(query="What is the place where James Bond lives?")'.
|
93 |
+
4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block.
|
94 |
+
5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters.
|
95 |
+
6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'.
|
96 |
+
7. Never create any notional variables in our code, as having these in your logs might derail you from the true variables.
|
97 |
+
8. You can use imports in your code, but only from the following list of modules: <<authorized_imports>>
|
98 |
+
9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist.
|
99 |
+
10. Don't give up! You're in charge of solving the task, not providing directions to solve it.
|
100 |
+
11. Only use the tools that have been provided to you.
|
101 |
+
12. Only generate an image when asked to do so.
|
102 |
+
13. If the task questions the rationale of your previous answers, explain your rationale for the previous answers and attempt to correct any mistakes in your previous answers.
|
103 |
+
|
104 |
+
As for your identity, your name is Agent SQuAD, you are an AI Agent, an expert guide to all questions and answers in the Stanford Question Answering Dataset (SQuAD), and you are SQuADtacular!
|
105 |
+
Do not use the squad_retriever tool to answer questions about yourself, such as "what is your name" or "what are you".
|
106 |
+
|
107 |
+
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
|
108 |
+
"""
|
run.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from gradio import ChatMessage
|
3 |
+
from transformers import load_tool, ReactCodeAgent, HfEngine # type: ignore
|
4 |
+
from utils import stream_from_transformers_agent
|
5 |
+
|
6 |
+
# Import tool from Hub
|
7 |
+
image_generation_tool = load_tool("m-ric/text-to-image")
|
8 |
+
|
9 |
+
llm_engine = HfEngine("meta-llama/Meta-Llama-3-70B-Instruct")
|
10 |
+
# Initialize the agent with both tools
|
11 |
+
agent = ReactCodeAgent(tools=[image_generation_tool], llm_engine=llm_engine)
|
12 |
+
|
13 |
+
def interact_with_agent(prompt, messages):
|
14 |
+
messages.append(ChatMessage(role="user", content=prompt))
|
15 |
+
yield messages
|
16 |
+
for msg in stream_from_transformers_agent(agent, prompt):
|
17 |
+
messages.append(msg)
|
18 |
+
yield messages
|
19 |
+
yield messages
|
20 |
+
|
21 |
+
with gr.Blocks() as demo:
|
22 |
+
stored_message = gr.State([])
|
23 |
+
chatbot = gr.Chatbot(label="Agent",
|
24 |
+
type="messages",
|
25 |
+
avatar_images=(None, "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"))
|
26 |
+
text_input = gr.Textbox(lines=1, label="Chat Message")
|
27 |
+
text_input.submit(lambda s: (s, ""), [text_input], [stored_message, text_input]).then(interact_with_agent, [stored_message, chatbot], [chatbot])
|
28 |
+
|
29 |
+
if __name__ == "__main__":
|
30 |
+
demo.launch()
|
test_bots.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytest
|
2 |
+
from deepeval import assert_test
|
3 |
+
from deepeval.metrics import AnswerRelevancyMetric
|
4 |
+
from deepeval.test_case import LLMTestCase
|
5 |
+
|
6 |
+
def test_case():
|
7 |
+
answer_relevancy_metric = AnswerRelevancyMetric(threshold=0.5)
|
8 |
+
test_case = LLMTestCase(
|
9 |
+
input="What if these shoes don't fit?",
|
10 |
+
# Replace this with the actual output from your LLM application
|
11 |
+
actual_output="We offer a 30-day full refund at no extra costs.",
|
12 |
+
retrieval_context=["All customers are eligible for a 30 day full refund at no extra costs."]
|
13 |
+
)
|
14 |
+
assert_test(test_case, [answer_relevancy_metric])
|
tools/squad_retriever.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers.agents.tools import Tool
|
2 |
+
from data import Data
|
3 |
+
|
4 |
+
class SquadRetrieverTool(Tool):
|
5 |
+
name = "squad_retriever"
|
6 |
+
description = "Answers questions from the Stanford Question Answering Dataset (SQuAD)."
|
7 |
+
inputs = {
|
8 |
+
"query": {
|
9 |
+
"type": "string",
|
10 |
+
"description": "The question. This should be the literal question being asked, only modified to be informed by chat history. Be sure to pass this as a keyword argument and not a dictionary.",
|
11 |
+
},
|
12 |
+
}
|
13 |
+
output_type = "string"
|
14 |
+
|
15 |
+
def __init__(self, **kwargs):
|
16 |
+
super().__init__(**kwargs)
|
17 |
+
self.data = Data()
|
18 |
+
self.query_engine = self.data.index.as_query_engine()
|
19 |
+
|
20 |
+
def forward(self, query: str) -> str:
|
21 |
+
assert isinstance(query, str), "Your search query must be a string"
|
22 |
+
|
23 |
+
response = self.query_engine.query(query)
|
24 |
+
# docs = self.data.index.similarity_search(query, k=3)
|
25 |
+
|
26 |
+
if len(response.response) == 0:
|
27 |
+
return "No answer found for this query."
|
28 |
+
return "Retrieved answer:\n\n" + "\n===Answer===\n".join(
|
29 |
+
[response.response]
|
30 |
+
)
|
tools/text_to_image.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers.agents.tools import Tool
|
2 |
+
from huggingface_hub import InferenceClient
|
3 |
+
|
4 |
+
class TextToImageTool(Tool):
|
5 |
+
description = "This is a tool that creates an image according to a prompt, which is a text description."
|
6 |
+
name = "image_generator"
|
7 |
+
inputs = {"prompt": {"type": "string", "description": "The image generator prompt. Don't hesitate to add details in the prompt to make the image look better, like 'high-res, photorealistic', etc."}}
|
8 |
+
output_type = "image"
|
9 |
+
model_sdxl = "stabilityai/stable-diffusion-xl-base-1.0"
|
10 |
+
client = InferenceClient(model_sdxl)
|
11 |
+
|
12 |
+
def forward(self, prompt):
|
13 |
+
return self.client.text_to_image(prompt)
|
tools/visual_qa.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import base64
|
3 |
+
from io import BytesIO
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
import requests
|
7 |
+
from typing import Optional
|
8 |
+
from huggingface_hub import InferenceClient
|
9 |
+
from transformers import AutoProcessor, Tool
|
10 |
+
import uuid
|
11 |
+
import mimetypes
|
12 |
+
from dotenv import load_dotenv
|
13 |
+
|
14 |
+
load_dotenv(override=True)
|
15 |
+
|
16 |
+
idefics_processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
|
17 |
+
|
18 |
+
def process_images_and_text(image_path, query, client):
|
19 |
+
messages = [
|
20 |
+
{
|
21 |
+
"role": "user", "content": [
|
22 |
+
{"type": "image"},
|
23 |
+
{"type": "text", "text": query},
|
24 |
+
]
|
25 |
+
},
|
26 |
+
]
|
27 |
+
|
28 |
+
prompt_with_template = idefics_processor.apply_chat_template(messages, add_generation_prompt=True)
|
29 |
+
|
30 |
+
# load images from local directory
|
31 |
+
|
32 |
+
# encode images to strings which can be sent to the endpoint
|
33 |
+
def encode_local_image(image_path):
|
34 |
+
# load image
|
35 |
+
image = Image.open(image_path).convert('RGB')
|
36 |
+
|
37 |
+
# Convert the image to a base64 string
|
38 |
+
buffer = BytesIO()
|
39 |
+
image.save(buffer, format="JPEG") # Use the appropriate format (e.g., JPEG, PNG)
|
40 |
+
base64_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
41 |
+
|
42 |
+
# add string formatting required by the endpoint
|
43 |
+
image_string = f"data:image/jpeg;base64,{base64_image}"
|
44 |
+
|
45 |
+
return image_string
|
46 |
+
|
47 |
+
|
48 |
+
image_string = encode_local_image(image_path)
|
49 |
+
prompt_with_images = prompt_with_template.replace("<image>", " ").format(image_string)
|
50 |
+
|
51 |
+
|
52 |
+
payload = {
|
53 |
+
"inputs": prompt_with_images,
|
54 |
+
"parameters": {
|
55 |
+
"return_full_text": False,
|
56 |
+
"max_new_tokens": 200,
|
57 |
+
}
|
58 |
+
}
|
59 |
+
|
60 |
+
return json.loads(client.post(json=payload).decode())[0]
|
61 |
+
|
62 |
+
# Function to encode the image
|
63 |
+
def encode_image(image_path):
|
64 |
+
if image_path.startswith("http"):
|
65 |
+
user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
|
66 |
+
request_kwargs = {
|
67 |
+
"headers": {"User-Agent": user_agent},
|
68 |
+
"stream": True,
|
69 |
+
}
|
70 |
+
|
71 |
+
# Send a HTTP request to the URL
|
72 |
+
response = requests.get(image_path, **request_kwargs)
|
73 |
+
response.raise_for_status()
|
74 |
+
content_type = response.headers.get("content-type", "")
|
75 |
+
|
76 |
+
extension = mimetypes.guess_extension(content_type)
|
77 |
+
if extension is None:
|
78 |
+
extension = ".download"
|
79 |
+
|
80 |
+
fname = str(uuid.uuid4()) + extension
|
81 |
+
download_path = os.path.abspath(os.path.join("downloads", fname))
|
82 |
+
|
83 |
+
with open(download_path, "wb") as fh:
|
84 |
+
for chunk in response.iter_content(chunk_size=512):
|
85 |
+
fh.write(chunk)
|
86 |
+
|
87 |
+
image_path = download_path
|
88 |
+
|
89 |
+
with open(image_path, "rb") as image_file:
|
90 |
+
return base64.b64encode(image_file.read()).decode('utf-8')
|
91 |
+
|
92 |
+
headers = {
|
93 |
+
"Content-Type": "application/json",
|
94 |
+
"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}"
|
95 |
+
}
|
96 |
+
|
97 |
+
|
98 |
+
def resize_image(image_path):
|
99 |
+
img = Image.open(image_path)
|
100 |
+
width, height = img.size
|
101 |
+
img = img.resize((int(width / 2), int(height / 2)))
|
102 |
+
new_image_path = f"resized_{image_path}"
|
103 |
+
img.save(new_image_path)
|
104 |
+
return new_image_path
|
105 |
+
|
106 |
+
|
107 |
+
class VisualQATool(Tool):
|
108 |
+
name = "visualizer"
|
109 |
+
description = "A tool that can answer questions about attached images."
|
110 |
+
inputs = {
|
111 |
+
"question": {"description": "the question to answer", "type": "text"},
|
112 |
+
"image_path": {
|
113 |
+
"description": "The path to the image on which to answer the question",
|
114 |
+
"type": "text",
|
115 |
+
},
|
116 |
+
}
|
117 |
+
output_type = "text"
|
118 |
+
|
119 |
+
client = InferenceClient("HuggingFaceM4/idefics2-8b-chatty")
|
120 |
+
|
121 |
+
def forward(self, image_path: str, question: Optional[str] = None) -> str:
|
122 |
+
add_note = False
|
123 |
+
if not question:
|
124 |
+
add_note = True
|
125 |
+
question = "Please write a detailed caption for this image."
|
126 |
+
try:
|
127 |
+
output = process_images_and_text(image_path, question, self.client)
|
128 |
+
except Exception as e:
|
129 |
+
print(e)
|
130 |
+
if "Payload Too Large" in str(e):
|
131 |
+
new_image_path = resize_image(image_path)
|
132 |
+
output = process_images_and_text(new_image_path, question, self.client)
|
133 |
+
|
134 |
+
if add_note:
|
135 |
+
output = f"You did not provide a particular question, so here is a detailed caption for the image: {output}"
|
136 |
+
|
137 |
+
return output
|
138 |
+
|
139 |
+
class VisualQAGPT4Tool(Tool):
|
140 |
+
name = "visualizer"
|
141 |
+
description = "A tool that can answer questions about attached images."
|
142 |
+
inputs = {
|
143 |
+
"question": {"description": "the question to answer", "type": "text"},
|
144 |
+
"image_path": {
|
145 |
+
"description": "The path to the image on which to answer the question. This should be a local path to downloaded image.",
|
146 |
+
"type": "text",
|
147 |
+
},
|
148 |
+
}
|
149 |
+
output_type = "text"
|
150 |
+
|
151 |
+
def forward(self, image_path: str, question: Optional[str] = None) -> str:
|
152 |
+
add_note = False
|
153 |
+
if not question:
|
154 |
+
add_note = True
|
155 |
+
question = "Please write a detailed caption for this image."
|
156 |
+
if not isinstance(image_path, str):
|
157 |
+
raise Exception("You should provide only one string as argument to this tool!")
|
158 |
+
|
159 |
+
base64_image = encode_image(image_path)
|
160 |
+
|
161 |
+
payload = {
|
162 |
+
"model": "gpt-4o",
|
163 |
+
"messages": [
|
164 |
+
{
|
165 |
+
"role": "user",
|
166 |
+
"content": [
|
167 |
+
{
|
168 |
+
"type": "text",
|
169 |
+
"text": question
|
170 |
+
},
|
171 |
+
{
|
172 |
+
"type": "image_url",
|
173 |
+
"image_url": {
|
174 |
+
"url": f"data:image/jpeg;base64,{base64_image}"
|
175 |
+
}
|
176 |
+
}
|
177 |
+
]
|
178 |
+
}
|
179 |
+
],
|
180 |
+
"max_tokens": 500
|
181 |
+
}
|
182 |
+
response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
|
183 |
+
try:
|
184 |
+
output = response.json()['choices'][0]['message']['content']
|
185 |
+
except Exception:
|
186 |
+
raise Exception(f"Response format unexpected: {response.json()}")
|
187 |
+
|
188 |
+
if add_note:
|
189 |
+
output = f"You did not provide a particular question, so here is a detailed caption for the image: {output}"
|
190 |
+
|
191 |
+
return output
|
tools/web_surfer.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Shamelessly stolen from Microsoft Autogen team: thanks to them for this great resource!
|
2 |
+
# https://github.com/microsoft/autogen/blob/gaia_multiagent_v01_march_1st/autogen/browser_utils.py
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
from typing import Tuple, Optional
|
6 |
+
from transformers.agents.agents import Tool
|
7 |
+
import time
|
8 |
+
from dotenv import load_dotenv
|
9 |
+
import requests
|
10 |
+
from pypdf import PdfReader
|
11 |
+
from markdownify import markdownify as md
|
12 |
+
import mimetypes
|
13 |
+
from .browser import SimpleTextBrowser
|
14 |
+
|
15 |
+
load_dotenv(override=True)
|
16 |
+
|
17 |
+
user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
|
18 |
+
|
19 |
+
browser_config = {
|
20 |
+
"viewport_size": 1024 * 5,
|
21 |
+
"downloads_folder": "coding",
|
22 |
+
"request_kwargs": {
|
23 |
+
"headers": {"User-Agent": user_agent},
|
24 |
+
"timeout": 300,
|
25 |
+
},
|
26 |
+
}
|
27 |
+
|
28 |
+
browser_config["serpapi_key"] = os.environ["SERPAPI_API_KEY"]
|
29 |
+
|
30 |
+
browser = SimpleTextBrowser(**browser_config)
|
31 |
+
|
32 |
+
|
33 |
+
# Helper functions
|
34 |
+
def _browser_state() -> Tuple[str, str]:
|
35 |
+
header = f"Address: {browser.address}\n"
|
36 |
+
if browser.page_title is not None:
|
37 |
+
header += f"Title: {browser.page_title}\n"
|
38 |
+
|
39 |
+
current_page = browser.viewport_current_page
|
40 |
+
total_pages = len(browser.viewport_pages)
|
41 |
+
|
42 |
+
address = browser.address
|
43 |
+
for i in range(len(browser.history)-2,-1,-1): # Start from the second last
|
44 |
+
if browser.history[i][0] == address:
|
45 |
+
header += f"You previously visited this page {round(time.time() - browser.history[i][1])} seconds ago.\n"
|
46 |
+
break
|
47 |
+
|
48 |
+
header += f"Viewport position: Showing page {current_page+1} of {total_pages}.\n"
|
49 |
+
return (header, browser.viewport)
|
50 |
+
|
51 |
+
|
52 |
+
class SearchInformationTool(Tool):
|
53 |
+
name="informational_web_search"
|
54 |
+
description="Perform an INFORMATIONAL web search query then return the search results."
|
55 |
+
inputs = {
|
56 |
+
"query": {
|
57 |
+
"type": "text",
|
58 |
+
"description": "The informational web search query to perform."
|
59 |
+
}
|
60 |
+
}
|
61 |
+
inputs["filter_year"]= {
|
62 |
+
"type": "text",
|
63 |
+
"description": "[Optional parameter]: filter the search results to only include pages from a specific year. For example, '2020' will only include pages from 2020. Make sure to use this parameter if you're trying to search for articles from a specific date!"
|
64 |
+
}
|
65 |
+
output_type = "text"
|
66 |
+
|
67 |
+
def forward(self, query: str, filter_year: Optional[int] = None) -> str:
|
68 |
+
browser.visit_page(f"google: {query}", filter_year=filter_year)
|
69 |
+
header, content = _browser_state()
|
70 |
+
return header.strip() + "\n=======================\n" + content
|
71 |
+
|
72 |
+
|
73 |
+
class NavigationalSearchTool(Tool):
|
74 |
+
name="navigational_web_search"
|
75 |
+
description="Perform a NAVIGATIONAL web search query then immediately navigate to the top result. Useful, for example, to navigate to a particular Wikipedia article or other known destination. Equivalent to Google's \"I'm Feeling Lucky\" button."
|
76 |
+
inputs = {"query": {"type": "text", "description": "The navigational web search query to perform."}}
|
77 |
+
output_type = "text"
|
78 |
+
|
79 |
+
def forward(self, query: str) -> str:
|
80 |
+
browser.visit_page(f"google: {query}")
|
81 |
+
|
82 |
+
# Extract the first line
|
83 |
+
m = re.search(r"\[.*?\]\((http.*?)\)", browser.page_content)
|
84 |
+
if m:
|
85 |
+
browser.visit_page(m.group(1))
|
86 |
+
|
87 |
+
# Return where we ended up
|
88 |
+
header, content = _browser_state()
|
89 |
+
return header.strip() + "\n=======================\n" + content
|
90 |
+
|
91 |
+
|
92 |
+
class VisitTool(Tool):
|
93 |
+
name="visit_page"
|
94 |
+
description="Visit a webpage at a given URL and return its text."
|
95 |
+
inputs = {"url": {"type": "text", "description": "The relative or absolute url of the webapge to visit."}}
|
96 |
+
output_type = "text"
|
97 |
+
|
98 |
+
def forward(self, url: str) -> str:
|
99 |
+
browser.visit_page(url)
|
100 |
+
header, content = _browser_state()
|
101 |
+
return header.strip() + "\n=======================\n" + content
|
102 |
+
|
103 |
+
|
104 |
+
class DownloadTool(Tool):
|
105 |
+
name="download_file"
|
106 |
+
description="""
|
107 |
+
Download a file at a given URL. The file should be of this format: [".xlsx", ".pptx", ".wav", ".mp3", ".png", ".docx"]
|
108 |
+
After using this tool, for further inspection of this page you should return the download path to your manager via final_answer, and they will be able to inspect it.
|
109 |
+
DO NOT use this tool for .pdf or .txt or .htm files: for these types of files use visit_page with the file url instead."""
|
110 |
+
inputs = {"url": {"type": "text", "description": "The relative or absolute url of the file to be downloaded."}}
|
111 |
+
output_type = "text"
|
112 |
+
|
113 |
+
def forward(self, url: str) -> str:
|
114 |
+
if "arxiv" in url:
|
115 |
+
url = url.replace("abs", "pdf")
|
116 |
+
response = requests.get(url)
|
117 |
+
content_type = response.headers.get("content-type", "")
|
118 |
+
extension = mimetypes.guess_extension(content_type)
|
119 |
+
if extension and isinstance(extension, str):
|
120 |
+
new_path = f"./downloads/file{extension}"
|
121 |
+
else:
|
122 |
+
new_path = "./downloads/file.object"
|
123 |
+
|
124 |
+
with open(new_path, "wb") as f:
|
125 |
+
f.write(response.content)
|
126 |
+
|
127 |
+
if "pdf" in extension or "txt" in extension or "htm" in extension:
|
128 |
+
raise Exception("Do not use this tool for pdf or txt or html files: use visit_page instead.")
|
129 |
+
|
130 |
+
return f"File was downloaded and saved under path {new_path}."
|
131 |
+
|
132 |
+
|
133 |
+
class PageUpTool(Tool):
|
134 |
+
name="page_up"
|
135 |
+
description="Scroll the viewport UP one page-length in the current webpage and return the new viewport content."
|
136 |
+
output_type = "text"
|
137 |
+
|
138 |
+
def forward(self) -> str:
|
139 |
+
browser.page_up()
|
140 |
+
header, content = _browser_state()
|
141 |
+
return header.strip() + "\n=======================\n" + content
|
142 |
+
|
143 |
+
class ArchiveSearchTool(Tool):
|
144 |
+
name="find_archived_url"
|
145 |
+
description="Given a url, searches the Wayback Machine and returns the archived version of the url that's closest in time to the desired date."
|
146 |
+
inputs={
|
147 |
+
"url": {"type": "text", "description": "The url you need the archive for."},
|
148 |
+
"date": {"type": "text", "description": "The date that you want to find the archive for. Give this date in the format 'YYYYMMDD', for instance '27 June 2008' is written as '20080627'."}
|
149 |
+
}
|
150 |
+
output_type = "text"
|
151 |
+
|
152 |
+
def forward(self, url, date) -> str:
|
153 |
+
archive_url = f"https://archive.org/wayback/available?url={url}×tamp={date}"
|
154 |
+
response = requests.get(archive_url).json()
|
155 |
+
try:
|
156 |
+
closest = response["archived_snapshots"]["closest"]
|
157 |
+
except:
|
158 |
+
raise Exception(f"Your url was not archived on Wayback Machine, try a different url.")
|
159 |
+
target_url = closest["url"]
|
160 |
+
browser.visit_page(target_url)
|
161 |
+
header, content = _browser_state()
|
162 |
+
return f"Web archive for url {url}, snapshot taken at date {closest['timestamp'][:8]}:\n" + header.strip() + "\n=======================\n" + content
|
163 |
+
|
164 |
+
|
165 |
+
class PageDownTool(Tool):
|
166 |
+
name="page_down"
|
167 |
+
description="Scroll the viewport DOWN one page-length in the current webpage and return the new viewport content."
|
168 |
+
output_type = "text"
|
169 |
+
|
170 |
+
def forward(self, ) -> str:
|
171 |
+
browser.page_down()
|
172 |
+
header, content = _browser_state()
|
173 |
+
return header.strip() + "\n=======================\n" + content
|
174 |
+
|
175 |
+
|
176 |
+
class FinderTool(Tool):
|
177 |
+
name="find_on_page_ctrl_f"
|
178 |
+
description="Scroll the viewport to the first occurrence of the search string. This is equivalent to Ctrl+F."
|
179 |
+
inputs = {"search_string": {"type": "text", "description": "The string to search for on the page. This search string supports wildcards like '*'" }}
|
180 |
+
output_type = "text"
|
181 |
+
|
182 |
+
def forward(self, search_string: str) -> str:
|
183 |
+
find_result = browser.find_on_page(search_string)
|
184 |
+
header, content = _browser_state()
|
185 |
+
|
186 |
+
if find_result is None:
|
187 |
+
return header.strip() + f"\n=======================\nThe search string '{search_string}' was not found on this page."
|
188 |
+
else:
|
189 |
+
return header.strip() + "\n=======================\n" + content
|
190 |
+
|
191 |
+
|
192 |
+
class FindNextTool(Tool):
|
193 |
+
name="find_next"
|
194 |
+
description="Scroll the viewport to next occurrence of the search string. This is equivalent to finding the next match in a Ctrl+F search."
|
195 |
+
inputs = {}
|
196 |
+
output_type = "text"
|
197 |
+
|
198 |
+
def forward(self, ) -> str:
|
199 |
+
find_result = browser.find_next()
|
200 |
+
header, content = _browser_state()
|
201 |
+
|
202 |
+
if find_result is None:
|
203 |
+
return header.strip() + "\n=======================\nThe search string was not found on this page."
|
204 |
+
else:
|
205 |
+
return header.strip() + "\n=======================\n" + content
|
utils.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from gradio import ChatMessage
|
4 |
+
from transformers.agents import ReactCodeAgent, agent_types
|
5 |
+
from typing import Generator
|
6 |
+
|
7 |
+
def pull_message(step_log: dict):
|
8 |
+
if step_log.get("rationale"):
|
9 |
+
yield ChatMessage(
|
10 |
+
role="assistant",
|
11 |
+
metadata={"title": "🧠 Rationale"},
|
12 |
+
content=step_log["rationale"]
|
13 |
+
)
|
14 |
+
if step_log.get("tool_call"):
|
15 |
+
used_code = step_log["tool_call"]["tool_name"] == "code interpreter"
|
16 |
+
content = step_log["tool_call"]["tool_arguments"]
|
17 |
+
if used_code:
|
18 |
+
content = f"```py\n{content}\n```"
|
19 |
+
yield ChatMessage(
|
20 |
+
role="assistant",
|
21 |
+
metadata={"title": f"🛠️ Used tool {step_log['tool_call']['tool_name']}"},
|
22 |
+
content=content,
|
23 |
+
)
|
24 |
+
if step_log.get("observation"):
|
25 |
+
yield ChatMessage(
|
26 |
+
role="assistant",
|
27 |
+
metadata={"title": "👀 Observation"},
|
28 |
+
content=f"```\n{step_log['observation']}\n```"
|
29 |
+
)
|
30 |
+
if step_log.get("error"):
|
31 |
+
yield ChatMessage(
|
32 |
+
role="assistant",
|
33 |
+
metadata={"title": "💥 Error"},
|
34 |
+
content=str(step_log["error"]),
|
35 |
+
)
|
36 |
+
|
37 |
+
def stream_from_transformers_agent(
|
38 |
+
agent: ReactCodeAgent, prompt: str,
|
39 |
+
) -> Generator[ChatMessage, None, ChatMessage | None]:
|
40 |
+
"""Runs an agent with the given prompt and streams the messages from the agent as ChatMessages."""
|
41 |
+
|
42 |
+
class Output:
|
43 |
+
output: agent_types.AgentType | str = None
|
44 |
+
|
45 |
+
step_log = None
|
46 |
+
for step_log in agent.run(prompt, stream=True, reset=len(agent.logs) == 0): # Reset=False misbehaves if the agent has not yet been run
|
47 |
+
if isinstance(step_log, dict):
|
48 |
+
for message in pull_message(step_log):
|
49 |
+
print("message", message)
|
50 |
+
yield message
|
51 |
+
|
52 |
+
Output.output = step_log
|
53 |
+
if isinstance(Output.output, agent_types.AgentText):
|
54 |
+
yield ChatMessage(
|
55 |
+
role="assistant", content=f"**Final answer:**\n```\n{Output.output.to_string()}\n```") # type: ignore
|
56 |
+
elif isinstance(Output.output, agent_types.AgentImage):
|
57 |
+
yield ChatMessage(
|
58 |
+
role="assistant",
|
59 |
+
content={"path": Output.output.to_string(), "mime_type": "image/png"}, # type: ignore
|
60 |
+
)
|
61 |
+
elif isinstance(Output.output, agent_types.AgentAudio):
|
62 |
+
yield ChatMessage(
|
63 |
+
role="assistant",
|
64 |
+
content={"path": Output.output.to_string(), "mime_type": "audio/wav"}, # type: ignore
|
65 |
+
)
|
66 |
+
else:
|
67 |
+
return ChatMessage(role="assistant", content=Output.output)
|