from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
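
# Assumed environment (the original does not pin dependencies): datasets,
# transformers, langchain, langchain-community, sentence-transformers, and
# faiss-cpu must be installed for these imports and the FAISS store below.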

# GAIA is a gated dataset on the Hugging Face Hub; loading it may require
# authenticating first (e.g. via `huggingface-cli login`).
knowledge_base = load_dataset("gaia-benchmark/GAIA", "2023_level1", split="test")
print(knowledge_base.column_names)
# ['task_id', 'Question', 'Level', 'Final answer', 'file_name', 'file_path', 'Annotator Metadata']

# Wrap every GAIA question in a LangChain Document, keeping the remaining
# dataset columns as metadata so they survive splitting and retrieval.
source_docs = [
    Document(
        page_content=doc["Question"],
        metadata={
            "task_id": doc["task_id"],
            "level": doc["Level"],
            "final_answer": doc["Final answer"],
            "file_name": doc["file_name"],
            "file_path": doc["file_path"],
            "annotator_metadata": doc["Annotator Metadata"],
        },
    )
    for doc in knowledge_base
]

# Because the splitter is built from the embedding model's own tokenizer,
# chunk_size and chunk_overlap are measured in tokens, so every chunk is
# guaranteed to fit the model's context window.
text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
    AutoTokenizer.from_pretrained("thenlper/gte-small"),
    chunk_size=200,
    chunk_overlap=20,
    add_start_index=True,
    strip_whitespace=True,
    separators=["\n\n", "\n", ".", " ", ""],
)

# Split docs and keep only unique chunks (different questions can produce
# identical text after splitting).
print("Splitting documents...")
docs_processed = []
unique_texts = set()
for doc in tqdm(source_docs):
    for new_doc in text_splitter.split_documents([doc]):
        if new_doc.page_content not in unique_texts:
            unique_texts.add(new_doc.page_content)
            docs_processed.append(new_doc)
print("Embedding documents... This should take a few minutes (5 minutes on MacBook with M1 Pro)")
embedding_model = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
vectordb = FAISS.from_documents(
documents=docs_processed,
embedding=embedding_model,
distance_strategy=DistanceStrategy.COSINE,
)
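
# Optional persistence, a minimal sketch (the folder name is an assumption,
# not part of the original script): save_local writes the FAISS index plus
# the docstore to disk, so the knowledge base can later be reloaded with
# FAISS.load_local using the same embedding model instead of re-embedding.
vectordb.save_local("gaia_faiss_index")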

if __name__ == "__main__":
    # Sanity check: report how many chunks ended up in the FAISS index.
    print(f"Indexed {vectordb.index.ntotal} chunks")
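
    # Quick retrieval demo, a minimal sketch: the query string below is a
    # made-up example, not part of the original script. similarity_search
    # returns the k nearest chunks by cosine distance, as configured above.
    query = "Which astronaut spent the least time in space?"
    for hit in vectordb.similarity_search(query, k=3):
        print(hit.metadata["task_id"], "->", hit.page_content[:80])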