File size: 2,324 Bytes
a39d9ba
 
 
3a2b6a9
 
a39d9ba
3a2b6a9
a39d9ba
 
 
 
 
 
45c51b6
7a79fec
a39d9ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a2b6a9
a39d9ba
3a2b6a9
0bd96a1
a39d9ba
 
 
 
3a2b6a9
a39d9ba
3a2b6a9
 
 
 
a39d9ba
3a2b6a9
a39d9ba
3a2b6a9
 
 
a39d9ba
 
3a2b6a9
a39d9ba
3a2b6a9
a39d9ba
 
 
 
3a2b6a9
 
3658bab
a39d9ba
3a2b6a9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import zipfile
from huggingface_hub import hf_hub_download
import gradio as gr
from sentence_transformers import SentenceTransformer
from langchain_chroma import Chroma

# Step 1: Download and Extract the Chroma Vector Store
def prepare_chroma_db(hf_token=None):
    """Ensure a local ``chroma_db`` directory exists, downloading it if needed.

    Fetches ``chroma_db.zip`` from the ``camiellia/phapdien_demo`` dataset
    repository on the Hugging Face Hub and unpacks it into ``./chroma_db``.
    If that directory is already present, the download is skipped entirely.

    Args:
        hf_token: Optional Hugging Face access token (needed for private repos).

    Returns:
        The persistence directory path, always ``"chroma_db"``.
    """
    persist_directory = "chroma_db"

    # Guard clause: nothing to do when the store is already on disk.
    if os.path.exists(persist_directory):
        print(f"{persist_directory} directory already exists.")
        return persist_directory

    print("Downloading chroma_db.zip from the dataset repository...")
    zip_path = hf_hub_download(
        repo_id="camiellia/phapdien_demo",  # dataset repository
        repo_type="dataset",
        filename="chroma_db.zip",
        token=hf_token,
    )
    print(f"Downloaded to {zip_path}")

    # Unpack the archive into the persistence directory.
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(persist_directory)
    print(f"Extracted chroma_db to ./{persist_directory}")

    return persist_directory

# Module-level side effect: download/extract the vector store once at import time.
persist_directory = prepare_chroma_db()

# Step 2: wrapper
class SentenceTransformerWrapper:
    """Adapter exposing a SentenceTransformer model through the LangChain
    embedding interface (``embed_documents`` / ``embed_query``)."""

    def __init__(self, model_name):
        # Load the underlying sentence-transformers model by name.
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts):
        """Embed a list of texts; returns a list of embedding lists."""
        vectors = self.model.encode(texts, show_progress_bar=True)
        return vectors.tolist()

    def embed_query(self, text):
        """Embed a single query string; returns one embedding list."""
        vector = self.model.encode(text)
        return vector.tolist()

# Vietnamese bi-encoder used for both document and query embeddings.
embedding_model = SentenceTransformerWrapper('bkai-foundation-models/vietnamese-bi-encoder')

# Step 3: Load the vector store from the directory
# NOTE(review): assumes the extracted zip placed the Chroma files directly
# under ./chroma_db (not in a nested subfolder) — confirm the archive layout.
vector_db = Chroma(
    persist_directory=persist_directory,
    embedding_function=embedding_model  # Use your SentenceTransformerWrapper instance
)

# Step 4: Gradio function
def retrieve_info(query, k):
    """Run a similarity search against the vector store and format the hits.

    Args:
        query: Free-text search query.
        k: Number of chunks to retrieve. Gradio's Number component delivers
           a float (or None when left blank), so the value is coerced to a
           positive int before querying Chroma.

    Returns:
        A string with one "Result i" section per hit, showing the document
        metadata and the first 1000 characters of its content. Empty string
        when there are no hits.
    """
    # Fix: gr.Number yields float/None, but similarity_search needs an int k >= 1.
    k = max(1, int(k or 1))
    results = vector_db.similarity_search(query, k=k)
    # Build the output in one pass instead of quadratic string concatenation.
    return "".join(
        f"Result {i+1}:\nMetadata: {doc.metadata}\nContent: {doc.page_content[:1000]}\n\n"
        for i, doc in enumerate(results)
    )

# Step 5: Launch the Gradio interface
demo = gr.Interface(
    fn=retrieve_info,
    # Fix: without a default, submitting an empty field passes None to the
    # handler and the search crashes. precision=0 makes the component return
    # an int (its raw value is a float), matching the integer k that
    # Chroma's similarity_search expects.
    inputs=[
        "text",
        gr.Number(label="k (Number of chunks to retrieve)", value=5, precision=0),
    ],
    outputs=[gr.Textbox(label="Output chunk(s)", lines=25)],
)

demo.launch()