File size: 4,178 Bytes
2fe32bb
 
 
 
 
 
 
 
9222de1
2fe32bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from typing import List, Optional
from uuid import uuid4
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer
from langchain.docstore.document import Document as LangchainDocument
from tqdm import tqdm

from langchain.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy

from z_document_reader import read_wiki_html

EMBEDDING_MODEL_NAME = "thenlper/gte-small"

def get_embedding_model():
    """Build the HuggingFace embedding model used for indexing and queries.

    Embeddings are L2-normalized (``normalize_embeddings=True``) so cosine
    similarity reduces to a dot product in the vector store.
    """
    return HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        multi_process=True,
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},  # Set `True` for cosine similarity
    )

def split_documents(
    chunk_size: int,
    knowledge_base: List[LangchainDocument],
    tokenizer_name: Optional[str] = EMBEDDING_MODEL_NAME,
) -> List[LangchainDocument]:
    """
    Split documents into chunks of maximum size `chunk_size` tokens and return a list of documents.

    Chunks overlap by 10% of `chunk_size`; exact-duplicate chunks (same
    `page_content`) are dropped, keeping the first occurrence.
    """
    splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        AutoTokenizer.from_pretrained(tokenizer_name),
        chunk_size=chunk_size,
        chunk_overlap=int(chunk_size / 10),
        add_start_index=True,
        strip_whitespace=True,
        # separators=MARKDOWN_SEPARATORS,
    )

    chunks = [
        chunk
        for source_doc in knowledge_base
        for chunk in splitter.split_documents([source_doc])
    ]

    # Order-preserving de-duplication on the chunk text.
    seen_texts = set()
    deduped = []
    for chunk in chunks:
        if chunk.page_content not in seen_texts:
            seen_texts.add(chunk.page_content)
            deduped.append(chunk)

    return deduped

def construct_vector_db(docs_processed, emb_model):
    """
    Build a FAISS vector store from pre-chunked documents.

    Args:
        docs_processed: list of LangchainDocument chunks to embed and index.
        emb_model: embedding model (see `get_embedding_model`).

    Returns:
        A FAISS vector store using cosine distance, matching the normalized
        embeddings produced by `get_embedding_model`.
    """
    # NOTE: removed dead commented-out Chroma alternative that sat after the
    # return statement — it was unreachable documentation-by-accident.
    return FAISS.from_documents(
        docs_processed, emb_model, distance_strategy=DistanceStrategy.COSINE
    )

def get_data_files(location: str = "_data/") -> list:
    """
    Return the paths of all HTML files (*.html, *.htm) directly inside `location`.

    Fix: the original concatenated `location + "*.html"`, which silently
    produced a wrong pattern (e.g. "_data*.html") when `location` lacked a
    trailing separator. `os.path.join` handles both cases, so the default
    "_data/" keeps working unchanged.
    """
    from glob import glob
    from os.path import join

    return glob(join(location, "*.html")) + glob(join(location, "*.htm"))

def generate_and_save_vector_store(vector_store_location: str = "cache_vector_store"):
    """
    One-time build: read every HTML data file, chunk the text and image
    documents, embed them, and persist two FAISS indexes to disk at
    `<vector_store_location>_text` and `<vector_store_location>_images`.
    """
    text_docs, image_docs = [], []
    for path in get_data_files():
        texts, images = read_wiki_html(path)
        text_docs.extend(texts)
        image_docs.extend(images)

    # 512-token chunks — a chunk size adapted to our embedding model.
    chunked_text = split_documents(512, text_docs, tokenizer_name=EMBEDDING_MODEL_NAME)
    chunked_images = split_documents(512, image_docs, tokenizer_name=EMBEDDING_MODEL_NAME)

    emb_model = get_embedding_model()

    construct_vector_db(chunked_text, emb_model).save_local(vector_store_location + "_text")
    construct_vector_db(chunked_images, emb_model).save_local(vector_store_location + "_images")

def load_vector_store(vector_store_location: str = "cache_vector_store"):
    '''Returns two vector stores; one for text and another for image
    '''
    emb_model = get_embedding_model()

    stores = []
    for suffix in ("_text", "_images"):
        stores.append(
            FAISS.load_local(
                vector_store_location + suffix,
                emb_model,
                # Cache written locally by generate_and_save_vector_store,
                # so deserializing it is considered trusted.
                allow_dangerous_deserialization=True,
            )
        )

    return stores[0], stores[1]

if __name__ == "__main__":
    # One-time index build; uncomment to (re)generate the on-disk stores.
    # generate_and_save_vector_store()
    # Smoke test: verify the cached vector stores load from disk.
    load_vector_store()
    pass