File size: 2,286 Bytes
f60ce50 59b7329 f60ce50 5f43612 f60ce50 5f43612 f60ce50 59b7329 f60ce50 5f43612 f60ce50 5f43612 f60ce50 59b7329 f60ce50 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
from transformers.agents.tools import Tool
from data import get_data
class SquadRetrieverTool(Tool):
    """Agent tool that looks up SQuAD passages for a free-text query.

    Each call is independent: nothing is carried over between queries, so
    the caller has to put all relevant detail into the query string.
    """

    # Metadata attributes declared for the ``Tool`` base class.
    name = "squad_retriever"
    description = """Retrieves documents from the Stanford Question Answering Dataset (SQuAD).
Because this tool does not remember context from previous queries, be sure to include
as many details as possible in your query.
"""
    inputs = {
        "query": {
            "type": "string",
            "description": "The query. Be sure to pass this as a keyword argument and not a dictionary.",
        },
    }
    output_type = "string"

    def __init__(self, **kwargs):
        """Load the dataset and build a retriever over its index.

        Args:
            **kwargs: Passed through unchanged to the ``Tool`` constructor.
        """
        super().__init__(**kwargs)
        # NOTE(review): `get_data` comes from the project-local `data` module;
        # `download=True` presumably fetches the dataset when it is missing --
        # confirm there, since it runs on every construction.
        self.data = get_data(download=True)
        self.retriever = self.data.index.as_retriever()

    def forward(self, query: str) -> str:
        """Return every retrieved passage, with its score, as one string.

        Args:
            query: The search text.

        Returns:
            ``"No documents found for this query."`` when the retriever
            yields nothing; otherwise one ``===Document===`` section per
            hit, each holding the hit's text followed by a ``Score:`` line.

        Raises:
            AssertionError: If ``query`` is not a ``str``.
        """
        assert isinstance(query, str), "Your search query must be a string"

        hits = self.retriever.retrieve(query)
        if len(hits) == 0:
            return "No documents found for this query."

        # Every section, including the first one, starts with this header.
        header = "===Document===\n"
        sections = []
        for hit in hits:
            sections.append(f"{hit.text}\nScore: {hit.score}")
        return header + ("\n" + header).join(sections)
class SquadQueryTool(Tool):
    """Agent tool that answers a question through a query engine over SQuAD.

    Each call is independent: nothing is carried over between queries, so
    the caller has to put all relevant detail into the question.
    """

    # Metadata attributes declared for the ``Tool`` base class.
    name = "squad_query"
    description = """Attempts to answer a question using the Stanford Question Answering Dataset (SQuAD).
Because this tool does not remember context from previous queries, be sure to include
as many details as possible in your query."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The question. Be sure to pass this as a keyword argument and not a dictionary.",
        },
    }
    output_type = "string"

    def __init__(self, **kwargs):
        """Load the dataset and build a query engine over its index.

        Args:
            **kwargs: Passed through unchanged to the ``Tool`` constructor.
        """
        super().__init__(**kwargs)
        # NOTE(review): `get_data` comes from the project-local `data` module;
        # `download=True` presumably fetches the dataset when it is missing --
        # confirm there, since it runs on every construction.
        self.data = get_data(download=True)
        self.query_engine = self.data.index.as_query_engine()

    def forward(self, query: str) -> str:
        """Run the question through the query engine and return its answer.

        Args:
            query: The question text.

        Returns:
            ``"No answer found for this query."`` when the engine produces
            no answer text (``None`` or an empty string); otherwise the
            answer prefixed with ``"Query Response:\\n\\n"``.

        Raises:
            AssertionError: If ``query`` is not a ``str``.
        """
        assert isinstance(query, str), "Your search query must be a string"

        answer = self.query_engine.query(query).response
        # The answer text may be None rather than "" (LlamaIndex, which this
        # engine appears to be, types it as optional), so test truthiness
        # instead of calling len() on it.
        if not answer:
            return "No answer found for this query."
        return "Query Response:\n\n" + answer
|