# app.py
import gradio as gr
import os
from typing import List, Tuple

from langchain_community.vectorstores import FAISS
from langchain_together import Together, TogetherEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# Environment variables for API keys
TOGETHER_API_KEY = os.getenv('TOGETHER_API_KEY')

class ChatBot:
    def __init__(self):
        # Load the pre-created FAISS index; the embedding model must match the
        # one used to build the index (the API key is read from TOGETHER_API_KEY)
        embeddings = TogetherEmbeddings()
        self.vectorstore = FAISS.load_local(
            "faiss_index",
            embeddings,
            allow_dangerous_deserialization=True,
        )
        self.retriever = self.vectorstore.as_retriever()
        
        # Initialize the model
        self.model = Together(
            model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
            temperature=0.7,
            max_tokens=128,
            top_k=50,
            together_api_key=TOGETHER_API_KEY
        )
        
        # Initialize memory
        self.memory = ConversationBufferMemory(
            return_messages=True,
            memory_key="chat_history",
            output_key="answer"
        )
        
        # Create the prompt template
        self.template = """<s>[INST] Based on the following context and chat history, answer the question naturally:

Context: {context}

Chat History: {chat_history}

Question: {question} [/INST]"""
        
        self.prompt = ChatPromptTemplate.from_template(self.template)
        
        # Build the LCEL chain: retrieve and format context, fill the prompt,
        # call the model, and parse the output to a plain string
        self.chain = (
            {
                "context": self.retriever
                | (lambda docs: "\n\n".join(d.page_content for d in docs)),
                "chat_history": lambda x: self.get_chat_history(),
                "question": RunnablePassthrough(),
            }
            | self.prompt
            | self.model
            | StrOutputParser()
        )
    
    def get_chat_history(self) -> str:
        """Format chat history for the prompt"""
        messages = self.memory.load_memory_variables({})["chat_history"]
        return "\n".join([f"{m.type}: {m.content}" for m in messages])
    
    def process_response(self, response: str) -> str:
        """Clean up the response"""
        response = response.replace("[/INST]", "").replace("<s>", "").replace("</s>", "")
        return response.strip()
    
    def chat(self, message: str, history: List[Tuple[str, str]]) -> str:
        """Process a single chat message"""
        response = self.chain.invoke(message)
        clean_response = self.process_response(response)
        # Save the exchange after generation so the current question is not
        # duplicated in the chat history injected into the prompt
        self.memory.chat_memory.add_user_message(message)
        self.memory.chat_memory.add_ai_message(clean_response)
        return clean_response

    def reset_chat(self) -> List[Tuple[str, str]]:
        """Reset the chat history"""
        self.memory.clear()
        return []

# Create the Gradio interface
def create_demo() -> gr.Interface:
    chatbot = ChatBot()
    
    with gr.Blocks() as demo:
        gr.Markdown(
            "# Knowledge Base Chatbot\n"
            "Ask questions about your documents and get informed responses!"
        )
        
        chatbot_interface = gr.Chatbot(
            height=600,
            show_copy_button=True,
        )
        
        with gr.Row():
            msg = gr.Textbox(
                show_label=False,
                placeholder="Type your message here...",
                container=False
            )
            submit = gr.Button("Send", variant="primary")
        
        clear = gr.Button("New Chat")
        
        def respond(message, chat_history):
            bot_message = chatbot.chat(message, chat_history)
            chat_history.append((message, bot_message))
            return "", chat_history
        
        submit.click(respond, [msg, chatbot_interface], [msg, chatbot_interface])
        msg.submit(respond, [msg, chatbot_interface], [msg, chatbot_interface])
        clear.click(lambda: chatbot.reset_chat(), None, chatbot_interface)
    
    return demo

demo = create_demo()

if __name__ == "__main__":
    demo.launch()
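
# --- Optional helper (not part of the original app) -------------------------
# A minimal sketch of how the "faiss_index" directory loaded above could be
# created. The document path, chunking parameters, and use of the default
# TogetherEmbeddings model are assumptions; they must match whatever was
# actually used when the index was built.
def build_index(doc_path: str = "docs.txt") -> None:
    from langchain_community.document_loaders import TextLoader
    from langchain_text_splitters import RecursiveCharacterTextSplitter

    # Load the source document and split it into overlapping chunks
    docs = TextLoader(doc_path).load()
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=100
    ).split_documents(docs)
    # Embed the chunks and persist the index next to app.py
    FAISS.from_documents(chunks, TogetherEmbeddings()).save_local("faiss_index")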