Upload folder using huggingface_hub
Browse files- app.py +2 -1
- tools/squad_tools.py +66 -0
app.py
CHANGED
@@ -3,7 +3,7 @@ from gradio import ChatMessage
|
|
3 |
from transformers import ReactCodeAgent, HfApiEngine
|
4 |
from utils import stream_from_transformers_agent
|
5 |
from prompts import SQUAD_REACT_CODE_SYSTEM_PROMPT
|
6 |
-
from tools.
|
7 |
from tools.text_to_image import TextToImageTool
|
8 |
from dotenv import load_dotenv
|
9 |
|
@@ -11,6 +11,7 @@ load_dotenv()
|
|
11 |
|
12 |
TASK_SOLVING_TOOLBOX = [
|
13 |
SquadRetrieverTool(),
|
|
|
14 |
TextToImageTool(),
|
15 |
]
|
16 |
|
|
|
3 |
from transformers import ReactCodeAgent, HfApiEngine
|
4 |
from utils import stream_from_transformers_agent
|
5 |
from prompts import SQUAD_REACT_CODE_SYSTEM_PROMPT
|
6 |
+
from tools.squad_tools import SquadRetrieverTool, SquadQueryTool
|
7 |
from tools.text_to_image import TextToImageTool
|
8 |
from dotenv import load_dotenv
|
9 |
|
|
|
11 |
|
12 |
TASK_SOLVING_TOOLBOX = [
|
13 |
SquadRetrieverTool(),
|
14 |
+
SquadQueryTool(),
|
15 |
TextToImageTool(),
|
16 |
]
|
17 |
|
tools/squad_tools.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers.agents.tools import Tool
|
2 |
+
from data import Data
|
3 |
+
|
4 |
+
|
5 |
+
class SquadRetrieverTool(Tool):
|
6 |
+
name = "squad_retriever"
|
7 |
+
description = """Retrieves documents from the Stanford Question Answering Dataset (SQuAD).
|
8 |
+
Because this tool does not remember context from previous queries, be sure to include any
|
9 |
+
relevant context in your query. Also, this tool only looks for affirmative matches, and does
|
10 |
+
not support negative queries, so only query for what you want, not what you don't want.
|
11 |
+
"""
|
12 |
+
inputs = {
|
13 |
+
"query": {
|
14 |
+
"type": "string",
|
15 |
+
"description": "The query. This could be the literal question being asked by the user, modified to be informed by your goals and chat history. Be sure to pass this as a keyword argument and not a dictionary.",
|
16 |
+
},
|
17 |
+
}
|
18 |
+
output_type = "string"
|
19 |
+
|
20 |
+
def __init__(self, **kwargs):
|
21 |
+
super().__init__(**kwargs)
|
22 |
+
self.data = Data()
|
23 |
+
self.retriever = self.data.index.as_retriever()
|
24 |
+
|
25 |
+
def forward(self, query: str) -> str:
|
26 |
+
assert isinstance(query, str), "Your search query must be a string"
|
27 |
+
|
28 |
+
responses = self.retriever.retrieve(query)
|
29 |
+
|
30 |
+
if len(responses) == 0:
|
31 |
+
return "No documents found for this query."
|
32 |
+
return "===Document===\n" + "\n===Document===\n".join(
|
33 |
+
[
|
34 |
+
f"{response.text}\nScore: {response.score}"
|
35 |
+
for response in responses
|
36 |
+
]
|
37 |
+
)
|
38 |
+
|
39 |
+
|
40 |
+
class SquadQueryTool(Tool):
|
41 |
+
name = "squad_query"
|
42 |
+
description = """Attempts to answer a question using the Stanford Question Answering Dataset (SQuAD).
|
43 |
+
Because this tool does not remember context from previous queries, be sure to include
|
44 |
+
any relevant context in your query."""
|
45 |
+
inputs = {
|
46 |
+
"query": {
|
47 |
+
"type": "string",
|
48 |
+
"description": "The question. This should be the literal question being asked, only modified to be informed by your goals and chat history. Be sure to pass this as a keyword argument and not a dictionary.",
|
49 |
+
},
|
50 |
+
}
|
51 |
+
output_type = "string"
|
52 |
+
|
53 |
+
def __init__(self, **kwargs):
|
54 |
+
super().__init__(**kwargs)
|
55 |
+
self.data = Data()
|
56 |
+
self.query_engine = self.data.index.as_query_engine()
|
57 |
+
|
58 |
+
def forward(self, query: str) -> str:
|
59 |
+
assert isinstance(query, str), "Your search query must be a string"
|
60 |
+
|
61 |
+
response = self.query_engine.query(query)
|
62 |
+
# docs = self.data.index.similarity_search(query, k=3)
|
63 |
+
|
64 |
+
if len(response.response) == 0:
|
65 |
+
return "No answer found for this query."
|
66 |
+
return "Query Response:\n\n" + "\n===Response===\n".join([response.response])
|