|
from transformers.agents.tools import Tool |
|
from data import get_data |
|
|
|
class SquadRetrieverTool(Tool):
    """Agent tool that retrieves SQuAD documents matching a free-text query.

    On construction the tool calls ``get_data(download=True)`` and wraps the
    resulting ``data.index`` in a retriever (``index.as_retriever()``).
    """

    name = "squad_retriever"
    # Built from adjacent literals so the prompt text carries no source
    # indentation.
    description = (
        "Retrieves documents from the Stanford Question Answering Dataset (SQuAD).\n"
        "Because this tool does not remember context from previous queries, be sure to include\n"
        "as many details as possible in your query.\n"
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The query. Be sure to pass this as a keyword argument and not a dictionary.",
        },
    }
    output_type = "string"

    def __init__(self, **kwargs):
        """Initialise the base ``Tool`` with ``kwargs`` and build the retriever."""
        super().__init__(**kwargs)
        self.data = get_data(download=True)
        self.retriever = self.data.index.as_retriever()

    def forward(self, query: str) -> str:
        """Return the retrieved documents for ``query`` as one formatted string.

        Each document is rendered as its text followed by a ``Score:`` line,
        and every document is preceded by a ``===Document===`` header line.
        A fixed message is returned when the retriever yields nothing.

        Raises:
            AssertionError: if ``query`` is not a ``str``.
        """
        assert isinstance(query, str), "Your search query must be a string"

        hits = self.retriever.retrieve(query)
        if len(hits) == 0:
            return "No documents found for this query."

        header = "===Document===\n"
        rendered = []
        for hit in hits:
            rendered.append(f"{hit.text}\nScore: {hit.score}")
        # The separator is the header with a leading newline, so every
        # document (including the first) sits under its own header line.
        return header + ("\n" + header).join(rendered)
|
|
|
|
|
class SquadQueryTool(Tool):
    """Agent tool that answers a question through a query engine over SQuAD.

    On construction the tool calls ``get_data(download=True)`` and wraps the
    resulting ``data.index`` in a query engine (``index.as_query_engine()``).
    """

    name = "squad_query"
    description = (
        "Attempts to answer a question using the Stanford Question Answering Dataset (SQuAD).\n"
        "Because this tool does not remember context from previous queries, be sure to include\n"
        "as many details as possible in your query."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The question. Be sure to pass this as a keyword argument and not a dictionary.",
        },
    }
    output_type = "string"

    def __init__(self, **kwargs):
        """Initialise the base ``Tool`` with ``kwargs`` and build the query engine."""
        super().__init__(**kwargs)
        self.data = get_data(download=True)
        self.query_engine = self.data.index.as_query_engine()

    def forward(self, query: str) -> str:
        """Run ``query`` through the query engine and return the answer text.

        Returns:
            ``"Query Response:\\n\\n"`` followed by the answer, or a fixed
            "no answer" message when the engine produced no text.

        Raises:
            AssertionError: if ``query`` is not a ``str``.
        """
        assert isinstance(query, str), "Your search query must be a string"

        response = self.query_engine.query(query)
        answer = response.response

        # Truthiness covers both an empty string and a missing (None) answer;
        # the previous ``len(...) == 0`` check raised TypeError on None.
        if not answer:
            return "No answer found for this query."
        return "Query Response:\n\n" + answer
|
|