Spaces:
Running
Running
File size: 2,324 Bytes
a39d9ba 3a2b6a9 a39d9ba 3a2b6a9 a39d9ba 45c51b6 7a79fec a39d9ba 3a2b6a9 a39d9ba 3a2b6a9 0bd96a1 a39d9ba 3a2b6a9 a39d9ba 3a2b6a9 a39d9ba 3a2b6a9 a39d9ba 3a2b6a9 a39d9ba 3a2b6a9 a39d9ba 3a2b6a9 a39d9ba 3a2b6a9 3658bab a39d9ba 3a2b6a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import os
import zipfile
from huggingface_hub import hf_hub_download
import gradio as gr
from sentence_transformers import SentenceTransformer
from langchain_chroma import Chroma
# Step 1: Download and Extract the Chroma Vector Store
def prepare_chroma_db(hf_token=None):
    """Ensure the Chroma vector store exists locally and return its path.

    Downloads ``chroma_db.zip`` from the ``camiellia/phapdien_demo`` dataset
    repository and unpacks it into ``./chroma_db``, unless that directory is
    already present on disk.

    Args:
        hf_token: Optional Hugging Face access token passed to the download.

    Returns:
        The local persist directory path ("chroma_db").
    """
    db_dir = "chroma_db"
    # Guard clause: nothing to do if the store was already extracted.
    if os.path.exists(db_dir):
        print(f"{db_dir} directory already exists.")
        return db_dir

    print("Downloading chroma_db.zip from the dataset repository...")
    archive_path = hf_hub_download(
        repo_id="camiellia/phapdien_demo",  # dataset repository
        repo_type="dataset",
        filename="chroma_db.zip",
        token=hf_token,
    )
    print(f"Downloaded to {archive_path}")

    # Unpack the archive into the persist directory.
    with zipfile.ZipFile(archive_path, 'r') as archive:
        archive.extractall(db_dir)
    print(f"Extracted chroma_db to ./{db_dir}")
    return db_dir
# Make sure the vector store is available on disk before wiring up Chroma below.
persist_directory = prepare_chroma_db()
# Step 2: wrapper
class SentenceTransformerWrapper:
    """Adapter exposing a SentenceTransformer model through the
    ``embed_documents`` / ``embed_query`` interface that langchain's
    Chroma wrapper expects of an embedding function.
    """

    def __init__(self, model_name):
        # Load the underlying sentence-transformers model by name.
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts):
        """Embed a batch of texts; returns a list of float lists."""
        vectors = self.model.encode(texts, show_progress_bar=True)
        return vectors.tolist()

    def embed_query(self, text):
        """Embed a single query string; returns a flat list of floats."""
        vector = self.model.encode(text)
        return vector.tolist()
# Vietnamese bi-encoder used to embed both the stored chunks and incoming queries.
embedding_model = SentenceTransformerWrapper('bkai-foundation-models/vietnamese-bi-encoder')
# Step 3: Load the vector store from the directory
vector_db = Chroma(
    persist_directory=persist_directory,
    embedding_function=embedding_model  # Use your SentenceTransformerWrapper instance
)
# Step 4: Gradio function
# Step 4: Gradio function
def retrieve_info(query, k):
    """Run a similarity search against the vector store and format the hits.

    Args:
        query: Free-text search query.
        k: Number of chunks to retrieve. Gradio's Number component delivers
           a float (or None), so it is coerced to a positive int here.

    Returns:
        A human-readable string with one "Result i" section per hit,
        showing its metadata and the first 1000 characters of content.
    """
    # gr.Number yields a float/None; Chroma's similarity_search needs an int k >= 1.
    k = max(1, int(k or 1))
    results = vector_db.similarity_search(query, k=k)
    # Build the sections once and join, instead of quadratic += concatenation.
    sections = [
        f"Result {i+1}:\nMetadata: {doc.metadata}\nContent: {doc.page_content[:1000]}\n\n"
        for i, doc in enumerate(results)
    ]
    return "".join(sections)
# Step 5: Launch the Gradio interface
demo = gr.Interface(
    fn=retrieve_info,
    inputs=[
        "text",
        # precision=0 makes Gradio deliver an integer rather than a float,
        # which is what Chroma's similarity_search expects for k.
        gr.Number(label="k (Number of chunks to retrieve)", value=5, precision=0),
    ],
    outputs=[gr.Textbox(label="Output chunk(s)", lines=25)],
)
demo.launch()
|