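"""Streamlit RAG chatbot for general Q&A about Malaysia.

Built on LlamaIndex with a HuggingFace LLM (TinyLlama by default), a BGE
embedding model, and a vector index persisted locally under ./vectordb.

To run locally (the filename app.py is an assumption; adjust to match this
file):

    export HF_TOKEN=<your HuggingFace token>
    streamlit run app.py
"""
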
import os

import streamlit as st
import torch
from transformers import BitsAndBytesConfig

from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.chat_engine import CondensePlusContextChatEngine
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.postprocessor import MetadataReplacementPostProcessor
from llama_index.core.retrievers import RecursiveRetriever
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.huggingface import HuggingFaceLLM

# Optional 4-bit NF4 quantization config (requires a CUDA GPU with the
# bitsandbytes package installed); enabled via the commented lines in
# initialize_llm below.
def configure_quantization():
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
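
# A minimal sketch of wiring the config above into a model load (assumes a
# CUDA GPU with bitsandbytes installed; the kwargs mirror those used in
# initialize_llm below):
#
#   llm = HuggingFaceLLM(
#       model_name="TinyLlama/TinyLlama_v1.1",
#       model_kwargs={"token": hf_token, "quantization_config": configure_quantization()},
#       device_map="auto",
#   )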

# Initialize the LLM (cached across Streamlit reruns)
@st.cache_resource
def initialize_llm(hf_token):
    # Alternative models previously tried: meta-llama/Meta-Llama-3-8B-Instruct,
    # meta-llama/Llama-2-7b-chat-hf, google/gemma-7b-it,
    # HuggingFaceH4/zephyr-7b-beta, GeneZC/MiniChat-2-3B, ericzzz/falcon-rw-1b-chat
    model_name = 'TinyLlama/TinyLlama_v1.1'
    # To enable 4-bit quantization, uncomment this line and the model_kwargs below:
    # quantization_config = configure_quantization()
    return HuggingFaceLLM(
        model_name=model_name,
        tokenizer_name=model_name,
        context_window=1900,
        # model_kwargs={"token": hf_token, "quantization_config": quantization_config},
        model_kwargs={"token": hf_token},
        tokenizer_kwargs={"token": hf_token},
        max_new_tokens=300,
        device_map="auto",
    )

# Load the persisted index if it exists; otherwise build it from the source
# documents and persist it. The parsed nodes are returned in either case,
# since they back the RecursiveRetriever's node_dict in handle_query.
def load_or_create_index(embed_model, directories, persist_dir):
    all_nodes = []
    for directory in directories:
        docs = SimpleDirectoryReader(input_dir=directory).load_data()
        nodes = Settings.node_parser.get_nodes_from_documents(docs)
        all_nodes.extend(nodes)

    if os.path.exists(persist_dir):
        index = load_index_from_storage(StorageContext.from_defaults(persist_dir=persist_dir))
    else:
        index = VectorStoreIndex(all_nodes, embed_model=embed_model)
        index.storage_context.persist(persist_dir=persist_dir)
    return index, all_nodes
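
# To force a rebuild (e.g. after adding documents under ./data), remove the
# persist directory before starting the app; a hedged sketch:
#
#   import shutil
#   shutil.rmtree("./vectordb", ignore_errors=True)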

# Function to reset chat engine memory
def reset_memory():
    st.session_state.memory.reset()
    st.write("Memory has been reset")

# Approximate the current memory size. This is a whitespace word count, not a
# true token count, so it understates what the model tokenizer would report.
def get_memory_size():
    chat_history = st.session_state.memory.get_all()
    total_tokens = sum(len(message.content.split()) for message in chat_history)
    return total_tokens
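
# For an exact count, one could reuse the model tokenizer instead of splitting
# on whitespace (hypothetical sketch, not wired into the app):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama_v1.1")
#   exact_tokens = sum(len(tokenizer.encode(m.content)) for m in chat_history)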

def handle_query(user_prompt, llm):
    # Wrap the vector retriever in a RecursiveRetriever so retrieved index
    # nodes can be resolved against the full node dictionary.
    vector_retriever_chunk = st.session_state.index.as_retriever(similarity_top_k=2)
    retriever_chunk = RecursiveRetriever(
        "vector",
        retriever_dict={"vector": vector_retriever_chunk},
        node_dict=st.session_state.all_nodes_dict,
        verbose=False,
    )

    # Chat memory is initialized in main() before this function is called.
    chat_engine = CondensePlusContextChatEngine(
        retriever=retriever_chunk,
        memory=st.session_state.memory,
        llm=llm,
        context_prompt=(
            "You are a chatbot, able to have normal friendly interactions, as well as to answer "
            "questions about Malaysia generally. "
            "Here are the relevant documents for the context:\n"
            "{context_str}"
            "\nInstruction: Use the previous chat history, or the context above, to interact with and help the user. "
            "If you don't know, please do not make up an answer."
        ),
        node_postprocessors=[MetadataReplacementPostProcessor(target_metadata_key="window")],
        verbose=False,
    )
    response = chat_engine.chat(user_prompt)
    return response
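
# The same engine can be exercised outside Streamlit (hedged sketch; reuses a
# retriever and LLM constructed as above with a fresh memory buffer):
#
#   memory = ChatMemoryBuffer.from_defaults(token_limit=1900)
#   engine = CondensePlusContextChatEngine(
#       retriever=retriever_chunk, llm=llm, memory=memory,
#   )
#   print(engine.chat("Tell me about Malaysian food.").response)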

def main():
    hf_token = os.environ.get("HF_TOKEN")
    # hf_token =  '' # Replace with your actual token
    persist_dir = "./vectordb"
    directories = [
        # '/kaggle/input/coursera-course-data',
        # '/kaggle/input/data-scientist-job-webscrape',
        './data'
    ]

    # Initialize LLM and Settings
    if 'llm' not in st.session_state:
        llm = initialize_llm(hf_token)
        st.session_state.llm = llm
        Settings.llm = llm
    
    if 'embed_model' not in st.session_state:
        embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-base-en-v1.5")
        st.session_state.embed_model = embed_model
        Settings.embed_model = embed_model
        
    Settings.chunk_size = 1024  # node chunk size used when parsing documents

    if 'index' not in st.session_state:
        # Load or create index
        index, all_nodes = load_or_create_index(st.session_state.embed_model, directories, persist_dir)
        st.session_state.index = index
        st.session_state.all_nodes_dict = {n.node_id: n for n in all_nodes}

    if 'memory' not in st.session_state:
        MEMORY_THRESHOLD = 2500  # token budget for the chat memory buffer
        st.session_state.memory = ChatMemoryBuffer.from_defaults(token_limit=MEMORY_THRESHOLD)

    # Streamlit UI
    st.title("Malaysia Q&A Chatbot")
    st.write("Ask me anything about Malaysia, and I'll try my best to help you!")

    if 'messages' not in st.session_state:
        st.session_state.messages = [{'role': 'assistant', 'content': 'Hello! I am Bot Axia. How can I help?'}]

    user_prompt = st.chat_input("Ask me anything:")
    if user_prompt:
        st.session_state.messages.append({'role': 'user', 'content': user_prompt})
        # chat() returns a chat response object; .response holds the text
        response = handle_query(user_prompt, st.session_state.llm).response
        st.session_state.messages.append({'role': 'assistant', 'content': response})

    for message in st.session_state.messages:
        with st.chat_message(message['role']):
            st.write(message['content'])

    
    st.write("Memory size: ", get_memory_size())
    if get_memory_size() > 1500:
        st.write('Memory exceeded')
        reset_memory()
        
    if st.button("Reset Chat"):
        st.session_state.messages = []
        reset_memory()

if __name__ == "__main__":
    main()