File size: 2,825 Bytes
f2f8da5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7fd0db3
38812af
f2f8da5
 
 
 
 
 
 
38812af
f2f8da5
 
 
38812af
f2f8da5
 
 
 
 
 
 
 
 
 
 
 
 
38812af
f2f8da5
 
38812af
 
70b1a35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2f8da5
70b1a35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38812af
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# from smolagents import Tool
# from langchain_community.retrievers import BM25Retriever
# from langchain.docstore.document import Document
# import datasets


# class GuestInfoRetrieverTool(Tool):
#     name = "guest_info_retriever"
#     description = "Retrieves detailed information about gala guests based on their name or relation."
#     inputs = {
#         "query": {
#             "type": "string",
#             "description": "The name or relation of the guest you want information about."
#         }
#     }
#     output_type = "string"

#     def __init__(self, docs):
#         self.is_initialized = False
#         self.retriever = BM25Retriever.from_documents(docs)
       

#     def forward(self, query: str):
#         results = self.retriever.get_relevant_documents(query)
#         if results:
#             return "\n\n".join([doc.page_content for doc in results[:3]])
#         else:
#             return "No matching guest information found."


# def load_guest_dataset():
#     # Load the dataset
#     guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")

#     # Convert dataset entries into Document objects
#     docs = [
#         Document(
#             page_content="\n".join([
#                 f"Name: {guest['name']}",
#                 f"Relation: {guest['relation']}",
#                 f"Description: {guest['description']}",
#                 f"Email: {guest['email']}"
#             ]),
#             metadata={"name": guest["name"]}
#         )
#         for guest in guest_dataset
#     ]

#     # Return the tool
#     return GuestInfoRetrieverTool(docs)


import datasets
from langchain.docstore.document import Document
from langchain_community.retrievers import BM25Retriever
from langchain.tools import Tool

# Fetch the invitee records (the "train" split of the course dataset).
guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")


def _guest_to_document(guest):
    """Render one guest record as a Document with one labelled line per field."""
    lines = (
        f"Name: {guest['name']}",
        f"Relation: {guest['relation']}",
        f"Description: {guest['description']}",
        f"Email: {guest['email']}",
    )
    return Document(page_content="\n".join(lines), metadata={"name": guest["name"]})


# One Document per guest, in dataset order.
docs = list(map(_guest_to_document, guest_dataset))

# Build the BM25 retriever over the guest documents.
bm25_retriever = BM25Retriever.from_documents(docs)

def extract_text(query: str) -> str:
    """Retrieve information about gala guests matching a name or relation.

    Passes ``query`` to the module-level ``bm25_retriever`` and joins the
    ``page_content`` of at most the first three returned documents, separated
    by blank lines.

    Args:
        query: The guest name or relation to look up.

    Returns:
        The joined text of up to three retrieved documents, or
        ``"No matching guest information found."`` when the query is blank
        or the retriever returns nothing.
    """
    # A blank query carries no search terms. A BM25 ranker would presumably
    # still hand back its top-k documents (all with equal score), which would
    # present arbitrary guests as if they were matches -- so answer directly.
    if not query or not query.strip():
        return "No matching guest information found."

    results = bm25_retriever.invoke(query)
    if not results:
        return "No matching guest information found."

    # Cap the answer at three guests to keep the tool output short.
    return "\n\n".join(doc.page_content for doc in results[:3])

# Wrap extract_text as a LangChain Tool; the name and description are the
# metadata supplied to the Tool alongside the callable.
guest_info_tool = Tool(
    func=extract_text,
    description="Retrieves detailed information about gala guests based on their name or relation.",
    name="guest_info_retriever",
)