SQuAD_Agent_Experiment / tools /squad_tools.py
vonliechti's picture
Upload folder using huggingface_hub
5f43612 verified
raw
history blame
2.29 kB
from transformers.agents.tools import Tool
from data import get_data
class SquadRetrieverTool(Tool):
    """Agent tool returning raw passages retrieved from the SQuAD index.

    The index comes from ``get_data(download=True)``. Its retriever is built
    once at construction time and reused for every query.
    """

    name = "squad_retriever"
    description = """Retrieves documents from the Stanford Question Answering Dataset (SQuAD).
Because this tool does not remember context from previous queries, be sure to include
as many details as possible in your query.
"""
    inputs = {
        "query": {
            "type": "string",
            "description": "The query. Be sure to pass this as a keyword argument and not a dictionary.",
        },
    }
    output_type = "string"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        dataset = get_data(download=True)
        self.data = dataset
        # NOTE(review): assumes `dataset.index` exposes `as_retriever()`
        # (LlamaIndex-style index) -- confirm against data.get_data.
        self.retriever = dataset.index.as_retriever()

    def forward(self, query: str) -> str:
        """Retrieve passages for ``query`` and render them as one string.

        Args:
            query: Free-text search query.

        Returns:
            ``"No documents found for this query."`` when nothing is retrieved.
            Otherwise, one section per hit: ``===Document===``, then the hit's
            text, then ``Score: <score>``, with sections separated by newlines.
        """
        assert isinstance(query, str), "Your search query must be a string"
        hits = self.retriever.retrieve(query)
        if not hits:
            return "No documents found for this query."
        # Prefixing each section with the header and joining on "\n" gives the
        # same text as prefixing once and joining on "\n===Document===\n".
        sections = [f"===Document===\n{hit.text}\nScore: {hit.score}" for hit in hits]
        return "\n".join(sections)
class SquadQueryTool(Tool):
    """Agent tool that answers a question via a query engine over the SQuAD index.

    The index comes from ``get_data(download=True)``. Its query engine is built
    once at construction time and reused for every question.
    """

    name = "squad_query"
    description = """Attempts to answer a question using the Stanford Question Answering Dataset (SQuAD).
Because this tool does not remember context from previous queries, be sure to include
as many details as possible in your query."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The question. Be sure to pass this as a keyword argument and not a dictionary.",
        },
    }
    output_type = "string"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.data = get_data(download=True)
        # NOTE(review): assumes `self.data.index` exposes `as_query_engine()`
        # (LlamaIndex-style index) -- confirm against data.get_data.
        self.query_engine = self.data.index.as_query_engine()

    def forward(self, query: str) -> str:
        """Answer ``query`` with the query engine.

        Args:
            query: The question to answer.

        Returns:
            ``"No answer found for this query."`` when the engine produced no
            text, or only whitespace. Otherwise ``"Query Response:\\n\\n"``
            followed by the answer text.
        """
        assert isinstance(query, str), "Your search query must be a string"
        response = self.query_engine.query(query)
        answer = response.response
        # The answer text is not guaranteed to be a str. Presumably this is a
        # LlamaIndex `Response`, whose `.response` is Optional[str] -- verify.
        # The previous `len(answer) == 0` check raised TypeError on None, so
        # test truthiness instead and treat whitespace-only output as empty.
        # NOTE(review): LlamaIndex may return the literal text "Empty Response"
        # when nothing is retrieved -- confirm whether that should also map to
        # the no-answer message.
        if not answer or not answer.strip():
            return "No answer found for this query."
        # The original joined a one-element list with "\n===Response===\n",
        # which never inserts the separator; plain concatenation is equivalent.
        return "Query Response:\n\n" + answer