File size: 1,128 Bytes
60d9d3a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from transformers.agents.tools import Tool
from data import Data
class SquadRetrieverTool(Tool):
    """Agent tool that answers a question by querying the index exposed by ``Data``.

    The class attributes below (``name``, ``description``, ``inputs``,
    ``output_type``) are the tool metadata read by the ``Tool`` base class.
    """

    name = "squad_retriever"
    description = "Answers questions from the Stanford Question Answering Dataset (SQuAD)."
    inputs = {
        "query": {
            "type": "string",
            "description": "The question. This should be the literal question being asked, only modified to be informed by chat history. Be sure to pass this as a keyword argument and not a dictionary.",
        },
    }
    output_type = "string"

    def __init__(self, **kwargs):
        """Create the tool and build its query engine once, up front.

        Args:
            **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        # NOTE(review): ``Data`` is project-local; presumably it loads or builds
        # the SQuAD index on construction -- confirm its cost before creating
        # many instances of this tool.
        self.data = Data()
        # The engine is created once here and reused by every ``forward`` call.
        self.query_engine = self.data.index.as_query_engine()

    def forward(self, query: str) -> str:
        """Run ``query`` through the query engine and format the answer.

        Args:
            query: The question to answer.

        Returns:
            ``"Retrieved answer:\\n\\n"`` followed by the engine's answer text,
            or ``"No answer found for this query."`` when the engine yields no
            text.

        Raises:
            TypeError: If ``query`` is not a string.
        """
        # Raise explicitly instead of using ``assert``: assertions are stripped
        # when Python runs with ``-O``, which would silently disable this check.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        answer = self.query_engine.query(query).response
        # A truthiness test covers both an empty string and ``None``; the
        # previous ``len(...) == 0`` check raised ``TypeError`` on ``None``.
        if not answer:
            return "No answer found for this query."
        return "Retrieved answer:\n\n" + answer