davemasino commited on
Commit
3abe8a3
·
1 Parent(s): ed6f788

First iteration of retriever based on lesson

Browse files
Files changed (1) hide show
  1. retriever.py +47 -0
retriever.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import Tool
2
+ from langchain_community.retrievers import BM25Retriever
3
+ from langchain.docstore.document import Document
4
+ import datasets
5
+
6
+
7
+ # Load the dataset
8
+ guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
9
+
10
+ # Convert dataset entries into Document objects
11
+ docs = [
12
+ Document(
13
+ page_content="\n".join([
14
+ f"Name: {guest['name']}",
15
+ f"Relation: {guest['relation']}",
16
+ f"Description: {guest['description']}",
17
+ f"Email: {guest['email']}"
18
+ ]),
19
+ metadata={"name": guest["name"]}
20
+ )
21
+ for guest in guest_dataset
22
+ ]
23
+
24
+ class GuestInfoRetrieverTool(Tool):
25
+ name = "guest_info_retriever"
26
+ description = "Retrieves detailed information about gala guests based on their name or relation."
27
+ inputs = {
28
+ "query": {
29
+ "type": "string",
30
+ "description": "The name or relation of the guest you want information about."
31
+ }
32
+ }
33
+ output_type = "string"
34
+
35
+ def __init__(self, docs):
36
+ self.is_initialized = False
37
+ self.retriever = BM25Retriever.from_documents(docs)
38
+
39
+ def forward(self, query: str):
40
+ results = self.retriever.get_relevant_documents(query)
41
+ if results:
42
+ return "\n\n".join([doc.page_content for doc in results[:3]])
43
+ else:
44
+ return "No matching guest information found."
45
+
46
+ # Initialize the tool
47
+ guest_info_tool = GuestInfoRetrieverTool(docs)