import os

import streamlit as st
import torch
from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.chat_engine import CondensePlusContextChatEngine
from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.retrievers import RecursiveRetriever
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
from transformers import BitsAndBytesConfig


# Configuration for 4-bit quantization (only used when the quantized model path is enabled below)
def configure_quantization():
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )


# Initialize the LLM (cached so the model is only loaded once per session)
@st.cache_resource
def initialize_llm(hf_token):
    # quantization_config = configure_quantization()
    # Other models tried: meta-llama/Meta-Llama-3-8B-Instruct, meta-llama/Llama-2-7b-chat-hf,
    # google/gemma-7b-it, HuggingFaceH4/zephyr-7b-beta, GeneZC/MiniChat-2-3B,
    # ericzzz/falcon-rw-1b-chat
    model_name = "TinyLlama/TinyLlama_v1.1"
    return HuggingFaceLLM(
        model_name=model_name,
        tokenizer_name=model_name,
        context_window=1900,
        # model_kwargs={"token": hf_token, "quantization_config": quantization_config},
        model_kwargs={"token": hf_token},
        tokenizer_kwargs={"token": hf_token},
        max_new_tokens=300,
        device_map="auto",
    )


# Load a persisted index if one exists; otherwise parse the documents and build one
def load_or_create_index(embed_model, directories, persist_dir):
    all_nodes = []
    for directory in directories:
        docs = SimpleDirectoryReader(input_dir=directory).load_data()
        nodes = Settings.node_parser.get_nodes_from_documents(docs)
        all_nodes.extend(nodes)

    if os.path.exists(persist_dir):
        index = load_index_from_storage(StorageContext.from_defaults(persist_dir=persist_dir))
    else:
        index = VectorStoreIndex(all_nodes, embed_model=embed_model)
        index.storage_context.persist(persist_dir=persist_dir)
    return index, all_nodes


# Reset the chat engine memory
def reset_memory():
    st.session_state.memory.reset()
    st.write("Memory has been reset")


# Approximate the current memory size by counting whitespace-separated words in the chat history
def get_memory_size():
    chat_history = st.session_state.memory.get_all()
    total_tokens = sum(len(message.content.split()) for message in chat_history)
    return total_tokens


def handle_query(user_prompt, llm):
    # Build the retriever and chat engine for this query
    vector_retriever_chunk = st.session_state.index.as_retriever(similarity_top_k=2)
    retriever_chunk = RecursiveRetriever(
        "vector",
        retriever_dict={"vector": vector_retriever_chunk},
        node_dict=st.session_state.all_nodes_dict,
        verbose=False,
    )

    MEMORY_THRESHOLD = 1900
    if "memory" not in st.session_state:
        st.session_state.memory = ChatMemoryBuffer.from_defaults(token_limit=MEMORY_THRESHOLD)

    chat_engine = CondensePlusContextChatEngine(
        retriever=retriever_chunk,
        memory=st.session_state.memory,
        llm=llm,
        context_prompt=(
            "You are a chatbot, able to have normal friendly interactions, as well as to answer "
            "questions about Malaysia generally. "
            "Here are the relevant documents for the context:\n"
            "{context_str}"
            "\nInstruction: Use the previous chat history, or the context above, to interact and help the user. "
            "If you don't know, please do not make up an answer."
        ),
        node_postprocessors=[MetadataReplacementPostProcessor(target_metadata_key="window")],
        verbose=False,
    )
    response = chat_engine.chat(user_prompt)
    return response


def main():
    hf_token = os.environ.get("HF_TOKEN")
    # hf_token = ""  # Or replace with your actual token instead of using the environment variable
    persist_dir = "./vectordb"
    directories = [
        # '/kaggle/input/coursera-course-data',
        # '/kaggle/input/data-scientist-job-webscrape',
        "./data",
    ]

    # Initialize LLM and Settings
    if "llm" not in st.session_state:
        llm = initialize_llm(hf_token)
        st.session_state.llm = llm
        Settings.llm = llm

    if "embed_model" not in st.session_state:
        embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-base-en-v1.5")
        st.session_state.embed_model = embed_model
        Settings.embed_model = embed_model

    Settings.chunk_size = 1024

    if "index" not in st.session_state:
        # Load or create the index and keep a node lookup for the recursive retriever
        index, all_nodes = load_or_create_index(st.session_state.embed_model, directories, persist_dir)
        st.session_state.index = index
        st.session_state.all_nodes_dict = {n.node_id: n for n in all_nodes}

    if "memory" not in st.session_state:
        MEMORY_THRESHOLD = 2500
        st.session_state.memory = ChatMemoryBuffer.from_defaults(token_limit=MEMORY_THRESHOLD)

    # Streamlit UI
    st.title("Malaysia Q&A Chatbot")
    st.write("Ask me anything about Malaysia, and I'll try my best to help you!")

    if "messages" not in st.session_state:
        st.session_state.messages = [
            {"role": "assistant", "content": "Hello! I am Bot Axia. How can I help?"}
        ]
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    user_prompt = st.chat_input("Ask me anything:")
    if user_prompt:
        st.session_state.messages.append({"role": "user", "content": user_prompt})
        response = handle_query(user_prompt, st.session_state.llm)
        st.session_state.messages.append({"role": "assistant", "content": response.response})

    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    st.write("Memory size: ", get_memory_size())
    if get_memory_size() > 1500:
        st.write("Memory exceeded")
        reset_memory()

    if st.button("Reset Chat"):
        st.session_state.messages = []
        reset_memory()


if __name__ == "__main__":
    main()
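

# Usage sketch (assumptions, not part of the original script): the file name app.py and the
# document folder ./data are illustrative; HF_TOKEN must hold a valid Hugging Face access token.
#   export HF_TOKEN=<your token>
#   streamlit run app.py
# On first run the documents under ./data are parsed and persisted to ./vectordb; later runs
# reload the index from that directory instead of rebuilding it.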