import os

import streamlit as st
import torch
from transformers import BitsAndBytesConfig
from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.chat_engine import CondensePlusContextChatEngine
from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.retrievers import RecursiveRetriever
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM
# Configuration for 4-bit quantization (NF4 with double quantization and
# fp16 compute) to shrink the model's memory footprint
def configure_quantization():
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
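# Quantization is currently disabled in initialize_llm() below. To re-enable
# it (a sketch, assuming a CUDA GPU with bitsandbytes installed), pass the
# config through model_kwargs:
#
#   quantization_config = configure_quantization()
#   model_kwargs={"token": hf_token, "quantization_config": quantization_config}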
# Initialize the LLM (cached so the model is only loaded once per session)
@st.cache_resource
def initialize_llm(hf_token):
    # quantization_config = configure_quantization()
    # Other models tried: meta-llama/Meta-Llama-3-8B-Instruct,
    # meta-llama/Llama-2-7b-chat-hf, google/gemma-7b-it,
    # HuggingFaceH4/zephyr-7b-beta, GeneZC/MiniChat-2-3B,
    # ericzzz/falcon-rw-1b-chat
    model_name = "TinyLlama/TinyLlama_v1.1"
    return HuggingFaceLLM(
        model_name=model_name,
        tokenizer_name=model_name,
        context_window=1900,
        # model_kwargs={"token": hf_token, "quantization_config": quantization_config},
        model_kwargs={"token": hf_token},
        tokenizer_kwargs={"token": hf_token},
        max_new_tokens=300,
        device_map="auto",
    )
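# Quick standalone sanity check (a sketch -- run outside Streamlit; assumes
# HF_TOKEN is set and the model weights are downloadable):
#
#   llm = initialize_llm(os.environ["HF_TOKEN"])
#   print(llm.complete("Name three Malaysian states.").text)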
# Load a persisted index if one exists; otherwise parse the source documents,
# build the index, and persist it for future runs
def load_or_create_index(embed_model, directories, persist_dir):
    all_nodes = []
    if os.path.exists(persist_dir):
        index = load_index_from_storage(StorageContext.from_defaults(persist_dir=persist_dir))
    else:
        # No persisted index yet: parse the documents into nodes with the
        # configured node parser, then embed and persist them
        for directory in directories:
            docs = SimpleDirectoryReader(input_dir=directory).load_data()
            nodes = Settings.node_parser.get_nodes_from_documents(docs)
            all_nodes.extend(nodes)
        index = VectorStoreIndex(all_nodes, embed_model=embed_model)
        index.storage_context.persist(persist_dir=persist_dir)
    return index, all_nodes
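# To force a rebuild of the index (e.g. after changing the source data or the
# chunking settings), delete the persist directory first -- a sketch:
#
#   import shutil
#   shutil.rmtree("./vectordb", ignore_errors=True)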
# Reset the chat engine's memory buffer
def reset_memory():
    st.session_state.memory.reset()
    st.write("Memory has been reset")

# Approximate the current memory size by counting whitespace-separated words
# in the chat history (a rough proxy for the true token count)
def get_memory_size():
    chat_history = st.session_state.memory.get_all()
    total_tokens = sum(len(message.content.split()) for message in chat_history)
    return total_tokens
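# If exact token counts matter, the model's own tokenizer would be closer
# (a sketch with a hypothetical attribute access -- HuggingFaceLLM keeps its
# tokenizer internally, and the attribute name may differ by version):
#
#   tok = st.session_state.llm._tokenizer
#   total = sum(len(tok.encode(m.content)) for m in st.session_state.memory.get_all())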
def handle_query(user_prompt, llm):
    # Wrap the index retriever in a RecursiveRetriever so that any retrieved
    # IndexNodes can be resolved against the full node dictionary
    vector_retriever_chunk = st.session_state.index.as_retriever(similarity_top_k=2)
    retriever_chunk = RecursiveRetriever(
        "vector",
        retriever_dict={"vector": vector_retriever_chunk},
        node_dict=st.session_state.all_nodes_dict,
        verbose=False,
    )
    # Fallback only: main() normally creates the memory buffer before this runs
    MEMORY_THRESHOLD = 1900
    if "memory" not in st.session_state:
        st.session_state.memory = ChatMemoryBuffer.from_defaults(token_limit=MEMORY_THRESHOLD)
    chat_engine = CondensePlusContextChatEngine(
        retriever=retriever_chunk,
        memory=st.session_state.memory,
        llm=llm,
        context_prompt=(
            "You are a chatbot, able to have normal friendly interactions, as well as to answer "
            "questions about Malaysia generally. "
            "Here are the relevant documents for the context:\n"
            "{context_str}"
            "\nInstruction: Use the previous chat history, or the context above, to interact and help the user. "
            "If you don't know, please do not make up an answer."
        ),
        node_postprocessors=[MetadataReplacementPostProcessor(target_metadata_key="window")],
        verbose=False,
    )
    response = chat_engine.chat(user_prompt)
    return response
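# CondensePlusContextChatEngine works in two steps on every turn: it first
# condenses the new question plus the chat history into a standalone query,
# then retrieves context for that query and answers with it. The
# MetadataReplacementPostProcessor swaps each retrieved chunk for the wider
# "window" stored in its metadata; that assumes the nodes came from a
# sentence-window parser -- with the default parser it falls back to the
# node's own content and is effectively a no-op.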
def main():
    hf_token = os.environ.get("HF_TOKEN")
    # hf_token = ''  # Replace with your actual token for local testing
    persist_dir = "./vectordb"
    directories = [
        # '/kaggle/input/coursera-course-data',
        # '/kaggle/input/data-scientist-job-webscrape',
        "./data"
    ]
    # Initialize LLM and Settings
    if "llm" not in st.session_state:
        llm = initialize_llm(hf_token)
        st.session_state.llm = llm
        Settings.llm = llm
    if "embed_model" not in st.session_state:
        embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-base-en-v1.5")
        st.session_state.embed_model = embed_model
        Settings.embed_model = embed_model
    Settings.chunk_size = 1024
    if "index" not in st.session_state:
        # Load or create index
        index, all_nodes = load_or_create_index(st.session_state.embed_model, directories, persist_dir)
        st.session_state.index = index
        st.session_state.all_nodes_dict = {n.node_id: n for n in all_nodes}
    if "memory" not in st.session_state:
        MEMORY_THRESHOLD = 2500
        st.session_state.memory = ChatMemoryBuffer.from_defaults(token_limit=MEMORY_THRESHOLD)
    # Streamlit UI
    st.title("Malaysia Q&A Chatbot")
    st.write("Ask me anything about Malaysia, and I'll try my best to help you!")
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hello! I am Bot Axia. How can I help?"}]
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    user_prompt = st.chat_input("Ask me anything:")
    if user_prompt:
        st.session_state.messages.append({"role": "user", "content": user_prompt})
        response = handle_query(user_prompt, st.session_state.llm)
        st.session_state.messages.append({"role": "assistant", "content": response.response})
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
    st.write("Memory size: ", get_memory_size())
    # Reset the buffer once the (word-count) size passes the threshold so the
    # condensed history stays within the model's context window
    if get_memory_size() > 1500:
        st.write("Memory exceeded")
        reset_memory()
    if st.button("Reset Chat"):
        st.session_state.messages = []
        reset_memory()

if __name__ == "__main__":
    main()
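# To run this app locally (a sketch -- assumes the Space's requirements are
# installed and a valid Hugging Face token is available):
#
#   export HF_TOKEN=hf_...
#   streamlit run app.py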