# app.py
import os
from typing import List, Tuple

import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_together import Together, TogetherEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

# Environment variables for API keys
TOGETHER_API_KEY = os.getenv('TOGETHER_API_KEY')
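# Fail fast if the key is missing (an added guard, not part of the original
# app logic; assumes the Space exposes the key as a repository secret)
if not TOGETHER_API_KEY:
    raise RuntimeError("TOGETHER_API_KEY environment variable is not set")
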
class ChatBot:
    def __init__(self):
        # Initialize embeddings
        self.embeddings = TogetherEmbeddings(
            model="togethercomputer/m2-bert-80M-8k-retrieval",
            together_api_key=TOGETHER_API_KEY
        )
        # Load the pre-created FAISS index with embeddings.
        # allow_dangerous_deserialization is required by recent
        # langchain_community releases to unpickle the index; only
        # enable it for index files you created yourself.
        self.vectorstore = FAISS.load_local(
            "faiss_index",
            embeddings=self.embeddings,
            allow_dangerous_deserialization=True
        )
        self.retriever = self.vectorstore.as_retriever()
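        # The "faiss_index" directory is assumed to have been built offline
        # with the same embedding model, along these lines (an illustrative
        # sketch; the loader, chunk sizes, and file names are assumptions):
        #
        #   docs = TextLoader("knowledge_base.txt").load()
        #   chunks = RecursiveCharacterTextSplitter(
        #       chunk_size=1000, chunk_overlap=100).split_documents(docs)
        #   FAISS.from_documents(chunks, embeddings).save_local("faiss_index")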
        # Initialize the model
        self.model = Together(
            model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
            temperature=0.7,
            max_tokens=128,
            top_k=50,
            together_api_key=TOGETHER_API_KEY
        )
        # Initialize memory
        self.memory = ConversationBufferMemory(
            return_messages=True,
            memory_key="chat_history",
            output_key="answer"
        )
        # Create the prompt template
        self.template = """<s>[INST] Based on the following context and chat history, answer the question naturally:
Context: {context}
Chat History: {chat_history}
Question: {question} [/INST]"""
        self.prompt = ChatPromptTemplate.from_template(self.template)
        # Create the chain: retrieve context, inject history, prompt the model.
        # Retrieved Documents are joined into plain text before templating.
        self.chain = (
            {
                "context": self.retriever
                | (lambda docs: "\n\n".join(doc.page_content for doc in docs)),
                "chat_history": lambda x: self.get_chat_history(),
                "question": RunnablePassthrough()
            }
            | self.prompt
            | self.model
            | StrOutputParser()
        )
    def get_chat_history(self) -> str:
        """Format chat history for the prompt"""
        messages = self.memory.load_memory_variables({})["chat_history"]
        return "\n".join([f"{m.type}: {m.content}" for m in messages])

    def process_response(self, response: str) -> str:
        """Clean up the response"""
        response = response.replace("[/INST]", "").replace("<s>", "").replace("</s>", "")
        return response.strip()

    def chat(self, message: str, history: List[Tuple[str, str]]) -> str:
        """Process a single chat message"""
        self.memory.chat_memory.add_user_message(message)
        response = self.chain.invoke(message)
        clean_response = self.process_response(response)
        self.memory.chat_memory.add_ai_message(clean_response)
        return clean_response

    def reset_chat(self) -> List[Tuple[str, str]]:
        """Reset the chat history"""
        self.memory.clear()
        return []
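# Example (illustrative, not part of the original app): the class can be
# exercised without the UI, e.g.
#   bot = ChatBot()
#   print(bot.chat("What does the knowledge base cover?", []))
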
# Create the Gradio interface
def create_demo() -> gr.Blocks:
    chatbot = ChatBot()
    with gr.Blocks() as demo:
        gr.Markdown("""# Knowledge Base Chatbot
Ask questions about your documents and get informed responses!""")
        chatbot_interface = gr.Chatbot(
            height=600,
            show_copy_button=True,
        )
        with gr.Row():
            msg = gr.Textbox(
                show_label=False,
                placeholder="Type your message here...",
                container=False
            )
            submit = gr.Button("Send", variant="primary")
            clear = gr.Button("New Chat")

        def respond(message, chat_history):
            bot_message = chatbot.chat(message, chat_history)
            chat_history.append((message, bot_message))
            return "", chat_history

        submit.click(respond, [msg, chatbot_interface], [msg, chatbot_interface])
        msg.submit(respond, [msg, chatbot_interface], [msg, chatbot_interface])
        clear.click(lambda: chatbot.reset_chat(), None, chatbot_interface)
    return demo

demo = create_demo()

if __name__ == "__main__":
    demo.launch()
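# On Hugging Face Spaces, launch() picks up the host and port automatically;
# share=True is only needed when tunneling from a local machine.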