# Builds and caches FAISS vector stores (text + image) from local wiki HTML files.
from typing import List, Optional
from uuid import uuid4
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer
from langchain.docstore.document import Document as LangchainDocument
from tqdm import tqdm
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy
from z_document_reader import read_wiki_html

EMBEDDING_MODEL_NAME = "thenlper/gte-small"
def get_embedding_model():
    """Build the CPU HuggingFace embedding model used for indexing and querying.

    Embeddings are L2-normalized so that FAISS cosine-distance retrieval
    behaves correctly.
    """
    return HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        multi_process=False,
        model_kwargs={"device": "cpu"},
        # `normalize_embeddings=True` is required for cosine similarity.
        encode_kwargs={"normalize_embeddings": True},
    )
def split_documents(
    chunk_size: int,
    knowledge_base: List[LangchainDocument],
    tokenizer_name: Optional[str] = EMBEDDING_MODEL_NAME,
) -> List[LangchainDocument]:
    """Split documents into token-bounded chunks and drop duplicate chunks.

    Args:
        chunk_size: Maximum chunk length, measured in tokens of `tokenizer_name`.
        knowledge_base: Documents to split.
        tokenizer_name: HuggingFace tokenizer used to count tokens.

    Returns:
        De-duplicated chunks (first occurrence wins); each chunk carries a
        `start_index` in its metadata.
    """
    text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        AutoTokenizer.from_pretrained(tokenizer_name),
        chunk_size=chunk_size,
        # 10% overlap preserves context across chunk boundaries.
        chunk_overlap=chunk_size // 10,
        add_start_index=True,
        strip_whitespace=True,
    )
    # The splitter accepts the whole list; no need to split one doc at a time.
    docs_processed = text_splitter.split_documents(knowledge_base)

    # Remove exact-text duplicates while preserving first-occurrence order.
    seen_texts = set()
    docs_processed_unique = []
    for doc in docs_processed:
        if doc.page_content not in seen_texts:
            seen_texts.add(doc.page_content)
            docs_processed_unique.append(doc)
    return docs_processed_unique
def construct_vector_db(docs_processed, emb_model):
    """Index the given document chunks in a FAISS store using cosine distance."""
    return FAISS.from_documents(
        docs_processed,
        emb_model,
        distance_strategy=DistanceStrategy.COSINE,
    )
def get_data_files(location: str = "_data/") -> List[str]:
    """Return paths of the HTML files directly under `location`.

    Args:
        location: Directory prefix; expected to end with a path separator
            (the glob pattern is built by plain string concatenation).

    Returns:
        Sorted list of `*.html` and `*.htm` file paths. Sorting makes the
        downstream indexing order deterministic — raw `glob` order is
        filesystem-dependent.
    """
    from glob import glob

    files = glob(location + "*.html") + glob(location + "*.htm")
    return sorted(files)
def generate_and_save_vector_store(vector_store_location: str = "cache_vector_store"):
    """One-time job: read the wiki HTML files, chunk them, and persist two
    FAISS stores — `<location>_text` and `<location>_images`.
    """
    text_docs, image_docs = [], []
    for path in get_data_files():
        page_text, page_images = read_wiki_html(path)
        text_docs.extend(page_text)
        image_docs.extend(page_images)

    # 512-token chunks — a chunk size adapted to the embedding model.
    chunked_text = split_documents(512, text_docs, tokenizer_name=EMBEDDING_MODEL_NAME)
    chunked_images = split_documents(512, image_docs, tokenizer_name=EMBEDDING_MODEL_NAME)

    embedder = get_embedding_model()
    construct_vector_db(chunked_text, embedder).save_local(vector_store_location + "_text")
    construct_vector_db(chunked_images, embedder).save_local(vector_store_location + "_images")
def load_vector_store(vector_store_location: str = "cache_vector_store"):
    """Return two FAISS vector stores: one for text chunks, one for image docs."""
    embedder = get_embedding_model()

    def _load(suffix: str):
        # Cache files are produced locally by this module, so
        # deserialization is trusted.
        return FAISS.load_local(
            vector_store_location + suffix,
            embedder,
            allow_dangerous_deserialization=True,
        )

    return _load("_text"), _load("_images")
if __name__ == "__main__":
    # Build the caches, then load them back as a smoke test.
    generate_and_save_vector_store()
    load_vector_store()