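"""Streamlit RAG chatbot for Malaysia Q&A.

Retrieval is built with LlamaIndex: BAAI/bge-base-en-v1.5 embeddings over a persisted
vector store (./vectordb), a RecursiveRetriever, and a CondensePlusContextChatEngine
backed by the TinyLlama/TinyLlama_v1.1 model loaded through HuggingFaceLLM.
"""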
import os

import streamlit as st
import torch
from transformers import BitsAndBytesConfig

from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.chat_engine import CondensePlusContextChatEngine
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.postprocessor import MetadataReplacementPostProcessor
from llama_index.core.retrievers import RecursiveRetriever
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM

# Configuration for quantization
def configure_quantization():
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
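
# Note: configure_quantization() is only applied if the commented-out quantization lines in
# initialize_llm() are re-enabled; 4-bit NF4 loading also requires a CUDA GPU and the
# bitsandbytes package.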

# Initialize the LLM
def initialize_llm(hf_token):
    # quantization_config = configure_quantization()
    # Alternative models: meta-llama/Meta-Llama-3-8B-Instruct, meta-llama/Llama-2-7b-chat-hf,
    # google/gemma-7b-it, HuggingFaceH4/zephyr-7b-beta, GeneZC/MiniChat-2-3B, ericzzz/falcon-rw-1b-chat
    model_name = "TinyLlama/TinyLlama_v1.1"
    return HuggingFaceLLM(
        model_name=model_name,
        tokenizer_name=model_name,
        context_window=1900,
        # model_kwargs={"token": hf_token, "quantization_config": quantization_config},
        model_kwargs={"token": hf_token},
        tokenizer_kwargs={"token": hf_token},
        max_new_tokens=300,
        device_map="auto",
    )

# Load a persisted index if one exists; otherwise build it from the source directories and persist it
def load_or_create_index(embed_model, directories, persist_dir):
    all_nodes = []
    if os.path.exists(persist_dir):
        index = load_index_from_storage(StorageContext.from_defaults(persist_dir=persist_dir))
    else:
        # Read and chunk the documents only when the index has to be (re)built
        for directory in directories:
            docs = SimpleDirectoryReader(input_dir=directory).load_data()
            nodes = Settings.node_parser.get_nodes_from_documents(docs)
            all_nodes.extend(nodes)
        index = VectorStoreIndex(all_nodes, embed_model=embed_model)
        index.storage_context.persist(persist_dir=persist_dir)
    return index, all_nodes

# Function to reset chat engine memory
def reset_memory():
    st.session_state.memory.reset()
    st.write("Memory has been reset")


# Function to get the current memory size (approximated as a whitespace word count, not true tokens)
def get_memory_size():
    chat_history = st.session_state.memory.get_all()
    total_tokens = sum(len(message.content.split()) for message in chat_history)
    return total_tokens
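
# handle_query wires up the per-turn RAG pipeline: a top-2 vector retriever wrapped in a
# RecursiveRetriever (so node references in all_nodes_dict can be followed), and a
# CondensePlusContextChatEngine, which condenses the chat history plus the new question into a
# standalone query, retrieves context for it, and answers with the LLM.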
def handle_query(user_prompt, llm):
    # Initialize retriever and chat engine
    vector_retriever_chunk = st.session_state.index.as_retriever(similarity_top_k=2)
    retriever_chunk = RecursiveRetriever(
        "vector",
        retriever_dict={"vector": vector_retriever_chunk},
        node_dict=st.session_state.all_nodes_dict,
        verbose=False,
    )
    # Fallback in case main() has not created the memory buffer yet
    MEMORY_THRESHOLD = 1900
    if "memory" not in st.session_state:
        st.session_state.memory = ChatMemoryBuffer.from_defaults(token_limit=MEMORY_THRESHOLD)
    chat_engine = CondensePlusContextChatEngine(
        retriever=retriever_chunk,
        memory=st.session_state.memory,
        llm=llm,
        context_prompt=(
            "You are a chatbot, able to have normal friendly interactions, as well as to answer "
            "questions about Malaysia generally. "
            "Here are the relevant documents for the context:\n"
            "{context_str}"
            "\nInstruction: Use the previous chat history, or the context above, to interact and help the user. "
            "If you don't know, please do not make up an answer."
        ),
        node_postprocessors=[MetadataReplacementPostProcessor(target_metadata_key="window")],
        verbose=False,
    )
    response = chat_engine.chat(user_prompt)
    return response

def main():
    hf_token = os.environ.get("HF_TOKEN")
    # hf_token = ''  # Replace with your actual token
    persist_dir = "./vectordb"
    directories = [
        # '/kaggle/input/coursera-course-data',
        # '/kaggle/input/data-scientist-job-webscrape',
        "./data",
    ]

    # Initialize LLM and Settings
    if "llm" not in st.session_state:
        llm = initialize_llm(hf_token)
        st.session_state.llm = llm
        Settings.llm = llm
    # llm = st.session_state.llm
    if "embed_model" not in st.session_state:
        embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-base-en-v1.5")
        st.session_state.embed_model = embed_model
        Settings.embed_model = embed_model
    # embed_model = st.session_state.embed_model
    Settings.chunk_size = 1024

    if "index" not in st.session_state:
        # Load or create index
        index, all_nodes = load_or_create_index(st.session_state.embed_model, directories, persist_dir)
        st.session_state.index = index
        st.session_state.all_nodes_dict = {n.node_id: n for n in all_nodes}

    if "memory" not in st.session_state:
        MEMORY_THRESHOLD = 2500
        st.session_state.memory = ChatMemoryBuffer.from_defaults(token_limit=MEMORY_THRESHOLD)

    # Streamlit UI
    st.title("Malaysia Q&A Chatbot")
    st.write("Ask me anything about Malaysia, and I'll try my best to help you!")

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hello! I am Bot Axia. How can I help?"}]
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    user_prompt = st.chat_input("Ask me anything:")
    if user_prompt:
        st.session_state.messages.append({"role": "user", "content": user_prompt})
        response = handle_query(user_prompt, st.session_state.llm)
        response = response.response
        st.session_state.messages.append({"role": "assistant", "content": response})

    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    st.write("Memory size: ", get_memory_size())
    if get_memory_size() > 1500:
        st.write("Memory exceeded")
        reset_memory()

    if st.button("Reset Chat"):
        st.session_state.messages = []
        reset_memory()


if __name__ == "__main__":
    main()
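
# To run locally (a sketch, assuming this script is saved as app.py and the packages streamlit,
# torch, transformers, llama-index, llama-index-llms-huggingface and
# llama-index-embeddings-huggingface are installed):
#   export HF_TOKEN=<your Hugging Face access token>
#   streamlit run app.py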