File size: 2,286 Bytes
f60ce50
59b7329
f60ce50
 
 
 
5f43612
 
f60ce50
 
 
 
5f43612
f60ce50
 
 
 
 
 
59b7329
f60ce50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f43612
f60ce50
 
 
5f43612
f60ce50
 
 
 
 
 
59b7329
f60ce50
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from transformers.agents.tools import Tool
from data import get_data

class SquadRetrieverTool(Tool):
    """Agent tool that returns raw passages retrieved from the SQuAD index.

    The index is obtained from ``get_data(download=True)`` when the tool is
    constructed. As the tool's ``description`` tells the agent, no context is
    kept between calls, so every query must be self-contained.
    """

    name = "squad_retriever"
    description = """Retrieves documents from the Stanford Question Answering Dataset (SQuAD). 
        Because this tool does not remember context from previous queries, be sure to include 
        as many details as possible in your query. 
        """
    inputs = {
        "query": {
            "type": "string",
            "description": "The query. Be sure to pass this as a keyword argument and not a dictionary.",
        },
    }
    output_type = "string"

    def __init__(self, **kwargs):
        """Load the dataset and build a retriever over its index.

        Args:
            **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        # NOTE(review): get_data is project-local; download=True presumably
        # fetches the data if it is not cached locally -- confirm in data.py.
        self.data = get_data(download=True)
        # Assumes ``self.data.index`` exposes ``as_retriever()`` (llama_index-style
        # index) -- TODO confirm against data.get_data.
        self.retriever = self.data.index.as_retriever()

    def forward(self, query: str) -> str:
        """Retrieve passages matching ``query`` and format them as one string.

        Args:
            query: Free-text search query.

        Returns:
            ``"No documents found for this query."`` when the retriever returns
            no hits; otherwise every hit's text followed by a ``Score:`` line,
            each hit preceded by a ``===Document===`` separator line.

        Raises:
            AssertionError: If ``query`` is not a ``str``.
        """
        # Kept as an assert so the exception type seen by existing callers is
        # unchanged; note that it is skipped entirely under ``python -O``.
        assert isinstance(query, str), "Your search query must be a string"

        responses = self.retriever.retrieve(query)

        # No hits: return an explicit message rather than an empty string.
        if not responses:
            return "No documents found for this query."
        return "===Document===\n" + "\n===Document===\n".join(
            f"{response.text}\nScore: {response.score}" for response in responses
        )


class SquadQueryTool(Tool):
    """Agent tool that asks a query engine over the SQuAD index to answer a question.

    Unlike ``squad_retriever``, this returns a synthesized answer string
    rather than the raw retrieved passages. As the tool's ``description``
    tells the agent, no context is kept between calls.
    """

    name = "squad_query"
    description = """Attempts to answer a question using the Stanford Question Answering Dataset (SQuAD).
        Because this tool does not remember context from previous queries, be sure to include 
        as many details as possible in your query."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The question. Be sure to pass this as a keyword argument and not a dictionary.",
        },
    }
    output_type = "string"

    def __init__(self, **kwargs):
        """Load the dataset and build a query engine over its index.

        Args:
            **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        # NOTE(review): get_data is project-local; download=True presumably
        # fetches the data if it is not cached locally -- confirm in data.py.
        self.data = get_data(download=True)
        # Assumes ``self.data.index`` exposes ``as_query_engine()`` (llama_index-style
        # index) -- TODO confirm against data.get_data.
        self.query_engine = self.data.index.as_query_engine()

    def forward(self, query: str) -> str:
        """Answer ``query`` with the query engine.

        Args:
            query: The question to answer.

        Returns:
            ``"No answer found for this query."`` when the engine produced no
            answer text; otherwise ``"Query Response:\\n\\n"`` followed by the
            answer text.

        Raises:
            AssertionError: If ``query`` is not a ``str``.
        """
        # Kept as an assert so the exception type seen by existing callers is
        # unchanged; note that it is skipped entirely under ``python -O``.
        assert isinstance(query, str), "Your search query must be a string"

        response = self.query_engine.query(query)
        answer = response.response

        # The answer text may be missing as well as empty: presumably this is a
        # llama_index Response, whose ``response`` field is typed Optional[str].
        # ``len(None)`` would raise TypeError, so test truthiness instead.
        if not answer:
            return "No answer found for this query."
        return "Query Response:\n\n" + answer