from transformers.agents.tools import Tool

from data import get_data


class SquadRetrieverTool(Tool):
    """Agent tool that returns the raw documents matching a query.

    The documents come from the retriever built on ``get_data(...).index``;
    each hit is rendered with its text and its retrieval score.
    """

    name = "squad_retriever"
    description = (
        "Retrieves documents from the Stanford Question Answering Dataset "
        "(SQuAD). Because this tool does not remember context from previous "
        "queries, be sure to include any relevant context in your query. "
        "Also, this tool only looks for affirmative matches, and does not "
        "support negative queries, so only query for what you want, not what "
        "you don't want. "
    )
    inputs = {
        "query": {
            "type": "string",
            "description": (
                "The query. This could be the literal question being asked "
                "by the user, modified to be informed by your goals and chat "
                "history. Be sure to pass this as a keyword argument and not "
                "a dictionary."
            ),
        },
    }
    output_type = "string"

    def __init__(self, **kwargs):
        """Load the dataset and build a retriever over its index.

        All keyword arguments are forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        # NOTE(review): presumably fetches the dataset when it is not cached
        # locally -- confirm against the `data` module.
        self.data = get_data(download=True)
        self.retriever = self.data.index.as_retriever()

    def forward(self, query: str) -> str:
        """Return the documents retrieved for ``query`` as one string.

        Args:
            query: The search text.

        Returns:
            Every retrieved document introduced by a ``===Document===``
            header and followed by its score, or a fixed "not found"
            message when the retriever yields nothing.

        Raises:
            AssertionError: If ``query`` is not a string.
        """
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if not isinstance(query, str):
            raise AssertionError("Your search query must be a string")

        responses = self.retriever.retrieve(query)
        if not responses:
            return "No documents found for this query."

        return "===Document===\n" + "\n===Document===\n".join(
            f"{response.text}\nScore: {response.score}" for response in responses
        )


class SquadQueryTool(Tool):
    """Agent tool that answers a question through the index's query engine.

    The answer text is read from the ``response`` attribute of whatever
    ``query_engine.query`` returns.
    """

    name = "squad_query"
    description = (
        "Attempts to answer a question using the Stanford Question Answering "
        "Dataset (SQuAD). Because this tool does not remember context from "
        "previous queries, be sure to include any relevant context in your "
        "query."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": (
                "The question. This should be the literal question being "
                "asked, only modified to be informed by your goals and chat "
                "history. Be sure to pass this as a keyword argument and not "
                "a dictionary."
            ),
        },
    }
    output_type = "string"

    def __init__(self, **kwargs):
        """Load the dataset and build a query engine over its index.

        All keyword arguments are forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        # NOTE(review): presumably fetches the dataset when it is not cached
        # locally -- confirm against the `data` module.
        self.data = get_data(download=True)
        self.query_engine = self.data.index.as_query_engine()

    def forward(self, query: str) -> str:
        """Return the query engine's answer to ``query``.

        Args:
            query: The question text.

        Returns:
            The answer prefixed with ``"Query Response:\\n\\n"``, or a fixed
            "no answer" message when the engine produced no answer text.

        Raises:
            AssertionError: If ``query`` is not a string.
        """
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if not isinstance(query, str):
            raise AssertionError("Your search query must be a string")

        answer = self.query_engine.query(query).response
        # A truthiness test covers both the empty string and a missing (None)
        # answer; calling len() on None would raise TypeError.
        if not answer:
            return "No answer found for this query."

        return "Query Response:\n\n" + answer