vonliechti commited on
Commit
f60ce50
·
verified ·
1 Parent(s): 3138d8f

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +2 -1
  2. tools/squad_tools.py +66 -0
app.py CHANGED
@@ -3,7 +3,7 @@ from gradio import ChatMessage
3
  from transformers import ReactCodeAgent, HfApiEngine
4
  from utils import stream_from_transformers_agent
5
  from prompts import SQUAD_REACT_CODE_SYSTEM_PROMPT
6
- from tools.squad_retriever import SquadRetrieverTool
7
  from tools.text_to_image import TextToImageTool
8
  from dotenv import load_dotenv
9
 
@@ -11,6 +11,7 @@ load_dotenv()
11
 
12
  TASK_SOLVING_TOOLBOX = [
13
  SquadRetrieverTool(),
 
14
  TextToImageTool(),
15
  ]
16
 
 
3
  from transformers import ReactCodeAgent, HfApiEngine
4
  from utils import stream_from_transformers_agent
5
  from prompts import SQUAD_REACT_CODE_SYSTEM_PROMPT
6
+ from tools.squad_tools import SquadRetrieverTool, SquadQueryTool
7
  from tools.text_to_image import TextToImageTool
8
  from dotenv import load_dotenv
9
 
 
11
 
12
  TASK_SOLVING_TOOLBOX = [
13
  SquadRetrieverTool(),
14
+ SquadQueryTool(),
15
  TextToImageTool(),
16
  ]
17
 
tools/squad_tools.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.agents.tools import Tool
2
+ from data import Data
3
+
4
+
5
+ class SquadRetrieverTool(Tool):
6
+ name = "squad_retriever"
7
+ description = """Retrieves documents from the Stanford Question Answering Dataset (SQuAD).
8
+ Because this tool does not remember context from previous queries, be sure to include any
9
+ relevant context in your query. Also, this tool only looks for affirmative matches, and does
10
+ not support negative queries, so only query for what you want, not what you don't want.
11
+ """
12
+ inputs = {
13
+ "query": {
14
+ "type": "string",
15
+ "description": "The query. This could be the literal question being asked by the user, modified to be informed by your goals and chat history. Be sure to pass this as a keyword argument and not a dictionary.",
16
+ },
17
+ }
18
+ output_type = "string"
19
+
20
+ def __init__(self, **kwargs):
21
+ super().__init__(**kwargs)
22
+ self.data = Data()
23
+ self.retriever = self.data.index.as_retriever()
24
+
25
+ def forward(self, query: str) -> str:
26
+ assert isinstance(query, str), "Your search query must be a string"
27
+
28
+ responses = self.retriever.retrieve(query)
29
+
30
+ if len(responses) == 0:
31
+ return "No documents found for this query."
32
+ return "===Document===\n" + "\n===Document===\n".join(
33
+ [
34
+ f"{response.text}\nScore: {response.score}"
35
+ for response in responses
36
+ ]
37
+ )
38
+
39
+
40
+ class SquadQueryTool(Tool):
41
+ name = "squad_query"
42
+ description = """Attempts to answer a question using the Stanford Question Answering Dataset (SQuAD).
43
+ Because this tool does not remember context from previous queries, be sure to include
44
+ any relevant context in your query."""
45
+ inputs = {
46
+ "query": {
47
+ "type": "string",
48
+ "description": "The question. This should be the literal question being asked, only modified to be informed by your goals and chat history. Be sure to pass this as a keyword argument and not a dictionary.",
49
+ },
50
+ }
51
+ output_type = "string"
52
+
53
+ def __init__(self, **kwargs):
54
+ super().__init__(**kwargs)
55
+ self.data = Data()
56
+ self.query_engine = self.data.index.as_query_engine()
57
+
58
+ def forward(self, query: str) -> str:
59
+ assert isinstance(query, str), "Your search query must be a string"
60
+
61
+ response = self.query_engine.query(query)
62
+ # docs = self.data.index.similarity_search(query, k=3)
63
+
64
+ if len(response.response) == 0:
65
+ return "No answer found for this query."
66
+ return "Query Response:\n\n" + "\n===Response===\n".join([response.response])