# Hugging Face file-viewer header captured with this file:
# author MikeMann · commit 5b61faf ("vectorstore") · 7.48 kB
import os

# Process-wide configuration, applied before the ML libraries below are imported.
os.environ.update(
    {
        # Disable CUDA initialization in this process.
        # NOTE(review): initialize_embedding_model requests device "cuda"; confirm
        # the deployment (e.g. the `spaces` GPU decorator) restores GPU visibility.
        "CUDA_VISIBLE_DEVICES": "",
        # NOTE(review): FAISS.load_local is given allow_dangerous_deserialization=True
        # as a keyword argument; this environment variable does not appear to be
        # read anywhere — confirm before relying on it.
        "allow_dangerous_deserialization": "True",
    }
)
print(os.getcwd())

# Pre-built FAISS index location; only reported here, the index is loaded later.
embedding_path = "/home/user/app/docs/_embeddings/index.faiss"
print(f"Loading FAISS index from: {embedding_path}")
_index_found = os.path.exists(embedding_path)
if not _index_found:
    print("File not found!")

# Hugging Face access token, supplied via the 'Gated_Repo' secret (None if unset).
HF_KEY = os.environ.get('Gated_Repo')
import spaces
import sys
import time
import re
import threading
from typing import List, Dict
import torch
import gradio as gr
from langchain_community.docstore import InMemoryDocstore
from langchain_community.document_loaders import TextLoader
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.docstore.document import Document as LangchainDocument
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores.utils import DistanceStrategy
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer, pipeline
from huggingface_hub import login
# Authenticate with the Hugging Face Hub (presumably needed for the gated Llama
# model loaded later). Without a token, login() falls back to an interactive
# prompt that cannot be answered in a headless deployment, so skip it and warn.
if HF_KEY:
    login(token=HF_KEY)
else:
    print("WARNING: 'Gated_Repo' secret is not set; skipping Hugging Face login.")
class BSIChatbot:
    """Retrieval-augmented chatbot: FAISS retrieval plus a Hugging Face causal LM, served via Gradio.

    Typical lifecycle: construct, call ``initialize_embedding_model`` and
    ``initialize_llm``, then ``launch_interface``.
    """

    def __init__(self, model_paths: Dict[str, str], docs_path: str):
        """Store model and document locations; models are loaded by the ``initialize_*`` methods.

        Args:
            model_paths: Must contain the keys ``'llm_path'``, ``'embed_model_path'``
                and ``'rerank_model_path'`` (a missing key raises ``KeyError``).
            docs_path: Directory holding the ``.md``/``.txt`` sources and the
                ``_embeddings`` FAISS index directory.
        """
        self.embedding_model = None
        self.llmpipeline = None
        self.llmtokenizer = None
        self.vectorstore = None
        # NOTE(review): nothing in this class loads a reranker, so this stays None
        # and rag_prompt(rerank=True) currently skips reranking.
        self.reranking_model = None
        self.streamer = None
        self.images = [None]
        self.llm_path = model_paths['llm_path']
        self.word_and_embed_model_path = model_paths['embed_model_path']
        self.docs = docs_path
        self.rerank_model_path = model_paths['rerank_model_path']

    @spaces.GPU
    def initialize_embedding_model(self, rebuild_embeddings: bool):
        """Load the embedding model, then build or load the FAISS vector store.

        Args:
            rebuild_embeddings: If True, read every ``.md``/``.txt`` file in
                ``self.docs``, split it into 512-token chunks, embed the chunks and
                save the index to ``<docs>/_embeddings``. If False, load the index
                previously saved there.

        Raises:
            ValueError: If a rebuild is requested but no ``.md``/``.txt`` files exist.
        """
        self.embedding_model = HuggingFaceEmbeddings(
            model_name=self.word_and_embed_model_path,
            multi_process=True,
            model_kwargs={"device": "cuda"},
            encode_kwargs={"normalize_embeddings": True},
        )
        index_dir = os.path.join(self.docs, "_embeddings")
        if rebuild_embeddings:
            raw_knowledge_base = []
            # Sorted so rebuilds are deterministic (os.listdir order is arbitrary).
            for filename in sorted(os.listdir(self.docs)):
                if not filename.endswith((".md", ".txt")):
                    continue
                file_path = os.path.join(self.docs, filename)
                encoding = 'utf-8' if filename.endswith(".md") else 'cp1252'
                with open(file_path, 'r', encoding=encoding) as file:
                    content = file.read()
                raw_knowledge_base.append(
                    LangchainDocument(page_content=content, metadata={"source": filename})
                )
            if not raw_knowledge_base:
                # FAISS cannot build an index from zero documents; fail clearly here.
                raise ValueError(f"No .md or .txt documents found in {self.docs!r}")

            tokenizer = AutoTokenizer.from_pretrained(self.word_and_embed_model_path)
            text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
                tokenizer=tokenizer,
                chunk_size=512,
                chunk_overlap=0,
                add_start_index=True,
                strip_whitespace=True,
            )
            # split_documents copies each document's metadata (incl. "source") onto its chunks.
            processed_docs = text_splitter.split_documents(raw_knowledge_base)

            self.vectorstore = FAISS.from_documents(
                processed_docs, self.embedding_model, distance_strategy=DistanceStrategy.COSINE
            )
            self.vectorstore.save_local(index_dir)
        else:
            # Loading deserializes the stored docstore; only load indexes this app wrote.
            self.vectorstore = FAISS.load_local(
                index_dir, self.embedding_model, allow_dangerous_deserialization=True
            )
        print("DBG: Vectorstore Status Initialization:", self.vectorstore)

    @spaces.GPU
    def retrieve_similar_embedding(self, query: str, k: int = 20):
        """Return the ``k`` chunks most similar to ``query`` from the vector store.

        Raises:
            RuntimeError: If ``initialize_embedding_model`` has not been called.
        """
        print("DBG: Vectorstore Status retriever:", self.vectorstore)
        if self.vectorstore is None:
            raise RuntimeError(
                "Vector store is not initialised; call initialize_embedding_model() first"
            )
        # multilingual-e5-*-instruct expects queries (not documents) in the form
        # "Instruct: <task>\nQuery: <query>".
        query = f"Instruct: Given a search query, retrieve the relevant passages that answer the query\nQuery: {query}"
        return self.vectorstore.similarity_search(query=query, k=k)

    @spaces.GPU
    def initialize_llm(self):
        """Load the 8-bit quantized causal LM and build the text-generation pipeline."""
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        llm = AutoModelForCausalLM.from_pretrained(self.llm_path, quantization_config=bnb_config)
        self.llmtokenizer = AutoTokenizer.from_pretrained(self.llm_path)
        # Default streamer bound to the pipeline; rag_prompt passes a fresh
        # per-request streamer at call time, which overrides this one.
        self.streamer = TextIteratorStreamer(self.llmtokenizer, skip_prompt=True)
        self.llmpipeline = pipeline(
            model=llm,
            tokenizer=self.llmtokenizer,
            task="text-generation",
            do_sample=True,
            temperature=0.7,
            repetition_penalty=1.1,
            return_full_text=False,
            streamer=self.streamer,
            max_new_tokens=500,
        )

    @spaces.GPU
    def rag_prompt(self, query: str, rerank: bool, history: List[Dict]):
        """Start answering ``query`` on a background thread and return its token stream.

        Args:
            query: The user's question.
            rerank: Rerank the retrieved chunks when a reranking model is loaded.
            history: Earlier chat turns as ``{"role": ..., "content": ...}`` dicts;
                should not include ``query`` itself.

        Returns:
            A ``TextIteratorStreamer`` that yields generated text until generation ends.

        NOTE(review): this returns before generation finishes; if @spaces.GPU only
        grants the GPU for the duration of this call, confirm the background
        generation still has GPU access.
        """
        retrieved_chunks = self.retrieve_similar_embedding(query)
        retrieved_texts = [
            f"{chunk.metadata['source']}:\n{chunk.page_content}" for chunk in retrieved_chunks
        ]
        if rerank and self.reranking_model:
            # NOTE(review): no reranker is loaded in this class; that rerank() returns
            # strings suitable for the join below is unverified.
            retrieved_texts = self.reranking_model.rerank(query, retrieved_texts, k=5)
        context = "\n".join(retrieved_texts)
        # Keep speaker roles so the model can tell user turns from its own answers.
        history_text = "\n".join(
            f"{turn['role']}: {turn['content']}" if turn.get('role') else turn['content']
            for turn in history
        )
        final_prompt = (
            f"Context:\n{context}\n---\n"
            f"History:\n{history_text}\n---\n"
            f"Question: {query}"
        )

        # A fresh streamer per request keeps concurrent or abandoned generations from
        # mixing tokens into each other's output.
        streamer = TextIteratorStreamer(self.llmtokenizer, skip_prompt=True)

        def _generate():
            # Runs the pipeline; on failure, ends the stream so the consumer is not
            # left waiting forever for tokens that will never arrive.
            try:
                self.llmpipeline(final_prompt, streamer=streamer)
            except Exception:
                streamer.end()
                raise

        threading.Thread(target=_generate).start()
        return streamer

    def launch_interface(self):
        """Build the Gradio chat UI and launch it."""
        with gr.Blocks() as demo:
            chatbot = gr.Chatbot(type="messages")
            msg = gr.Textbox()
            clear = gr.Button("Clear")
            reset = gr.Button("Reset")

            def user_input(user_message, history):
                # Clear the textbox and append the user's turn to the chat.
                return "", history + [{"role": "user", "content": user_message}]

            def bot_response(history):
                # The last entry is the question just added by user_input; pass only
                # the earlier turns as history so the question is not duplicated.
                query = history[-1]['content']
                response = self.rag_prompt(query, True, history[:-1])
                history.append({"role": "assistant", "content": ""})
                for token in response:
                    history[-1]['content'] += token
                    yield history

            msg.submit(user_input, [msg, chatbot], [msg, chatbot]).then(bot_response, chatbot, chatbot)
            clear.click(lambda: None, None, chatbot)
            reset.click(lambda: [], outputs=chatbot)
        demo.launch()
if __name__ == '__main__':
    # Directory with the source documents and the saved "_embeddings" index.
    docs_path = '/home/user/app/docs'
    # Hugging Face Hub model ids consumed by BSIChatbot.
    model_paths = dict(
        llm_path='meta-llama/Llama-3.2-3B-Instruct',
        embed_model_path='intfloat/multilingual-e5-large-instruct',
        rerank_model_path='domci/ColBERTv2-mmarco-de-0.1',
    )
    bot = BSIChatbot(model_paths=model_paths, docs_path=docs_path)
    # Reuse the saved FAISS index rather than re-embedding every document.
    bot.initialize_embedding_model(rebuild_embeddings=False)
    bot.initialize_llm()
    bot.launch_interface()