RichardHu commited on
Commit
70b1a35
Β·
verified Β·
1 Parent(s): 7fce2c7

Update retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +36 -11
retriever.py CHANGED
@@ -50,16 +50,41 @@
50
  # return GuestInfoRetrieverTool(docs)
51
 
52
 
53
- from langchain_community.vectorstores import Chroma
54
- from langchain_openai import OpenAIEmbeddings
55
-
56
- def get_retriever():
57
- """εˆ›ε»ΊεΉΆθΏ”ε›žζ£€η΄’ε™¨"""
58
- embeddings = OpenAIEmbeddings()
59
- vectorstore = Chroma(
60
- embedding_function=embeddings,
61
- persist_directory="./chroma_db",
62
- collection_name="rag_docs"
 
 
 
 
 
 
 
 
63
  )
64
- return vectorstore.as_retriever(search_kwargs={"k": 5})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
 
50
  # return GuestInfoRetrieverTool(docs)
51
 
52
 
53
+ import datasets
54
+ from langchain.docstore.document import Document
55
+ from langchain_community.retrievers import BM25Retriever
56
+ from langchain.tools import Tool
57
+
58
+ # Load the dataset
59
+ guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
60
+
61
+ # Convert dataset entries into Document objects
62
+ docs = [
63
+ Document(
64
+ page_content="\n".join([
65
+ f"Name: {guest['name']}",
66
+ f"Relation: {guest['relation']}",
67
+ f"Description: {guest['description']}",
68
+ f"Email: {guest['email']}"
69
+ ]),
70
+ metadata={"name": guest["name"]}
71
  )
72
+ for guest in guest_dataset
73
+ ]
74
+
75
+ bm25_retriever = BM25Retriever.from_documents(docs)
76
+
77
+ def extract_text(query: str) -> str:
78
+ """Retrieves detailed information about gala guests based on their name or relation."""
79
+ results = bm25_retriever.invoke(query)
80
+ if results:
81
+ return "\n\n".join([doc.page_content for doc in results[:3]])
82
+ else:
83
+ return "No matching guest information found."
84
+
85
+ guest_info_tool = Tool(
86
+ name="guest_info_retriever",
87
+ func=extract_text,
88
+ description="Retrieves detailed information about gala guests based on their name or relation."
89
+ )
90