# beeline — app.py (Hugging Face Space "beeline", OrifjonKenjayev; revision fe64323)
# NOTE: the original scrape included raw web-page chrome ("raw / history blame /
# 4.18 kB") above this line, which is not Python and has been folded into this
# comment so the file parses.
# app.py
import os
from operator import itemgetter
from typing import List, Tuple

import gradio as gr
from langchain.memory import ConversationBufferMemory
from langchain.schema import format_document
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_together import Together, TogetherEmbeddings
# Together AI credentials come from the environment — never hard-coded.
TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY")
class ChatBot:
    """Retrieval-augmented chatbot.

    Answers questions by retrieving context from a pre-built local FAISS
    index and prompting a Together-hosted Llama model, while keeping the
    running conversation in a buffer memory.
    """

    def __init__(self):
        # Embeddings used to encode queries against the pre-built index.
        self.embeddings = TogetherEmbeddings(
            model="togethercomputer/m2-bert-80M-8k-retrieval",
            together_api_key=TOGETHER_API_KEY,
        )

        # Load the pre-created FAISS index shipped with the app.
        # FIX: recent langchain-community versions refuse to unpickle a
        # saved index unless the caller explicitly opts in; without this
        # flag load_local() raises at startup. Only safe because the
        # "faiss_index" directory is our own artifact, not user input.
        self.vectorstore = FAISS.load_local(
            "faiss_index",
            embeddings=self.embeddings,
            allow_dangerous_deserialization=True,
        )
        self.retriever = self.vectorstore.as_retriever()

        # Chat model (Together-hosted Llama instruct variant).
        self.model = Together(
            model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
            temperature=0.7,
            max_tokens=128,
            top_k=50,
            together_api_key=TOGETHER_API_KEY,
        )

        # Conversation memory; output_key matches nothing we store manually,
        # but keeps save_context semantics consistent if used later.
        self.memory = ConversationBufferMemory(
            return_messages=True,
            memory_key="chat_history",
            output_key="answer",
        )

        # Prompt in the Llama [INST] instruct format.
        self.template = """<s>[INST] Based on the following context and chat history, answer the question naturally:
Context: {context}
Chat History: {chat_history}
Question: {question} [/INST]"""
        self.prompt = ChatPromptTemplate.from_template(self.template)

        # LCEL pipeline: fan the question into retrieved context + history,
        # format the prompt, call the model, parse the text out.
        self.chain = (
            {
                "context": self.retriever,
                "chat_history": lambda x: self.get_chat_history(),
                "question": RunnablePassthrough(),
            }
            | self.prompt
            | self.model
            | StrOutputParser()
        )

    def get_chat_history(self) -> str:
        """Render the stored messages as "role: content" lines for the prompt."""
        messages = self.memory.load_memory_variables({})["chat_history"]
        return "\n".join(f"{m.type}: {m.content}" for m in messages)

    def process_response(self, response: str) -> str:
        """Strip Llama special tokens the model may echo back, then trim whitespace."""
        response = response.replace("[/INST]", "").replace("<s>", "").replace("</s>", "")
        return response.strip()

    def chat(self, message: str, history: List[Tuple[str, str]]) -> str:
        """Answer one user message.

        Records the user message, runs the retrieval chain, cleans the
        model output, records it as the AI turn, and returns it. The
        gradio-style `history` argument is accepted for interface
        compatibility; memory is the source of truth here.
        """
        self.memory.chat_memory.add_user_message(message)
        response = self.chain.invoke(message)
        clean_response = self.process_response(response)
        self.memory.chat_memory.add_ai_message(clean_response)
        return clean_response

    def reset_chat(self) -> List[Tuple[str, str]]:
        """Clear the conversation memory and return an empty gradio history."""
        self.memory.clear()
        return []
def create_demo() -> gr.Blocks:
    """Build and return the Gradio app.

    FIX: the return annotation previously said ``gr.Interface``, but the
    function constructs and returns a ``gr.Blocks`` layout.

    NOTE(review): a single ChatBot (and thus a single conversation memory)
    is shared by every visitor to the app — concurrent users will
    interleave their histories. Confirm whether per-session state is wanted.
    """
    chatbot = ChatBot()

    with gr.Blocks() as demo:
        gr.Markdown("""# Knowledge Base Chatbot
Ask questions about your documents and get informed responses!""")

        chatbot_interface = gr.Chatbot(
            height=600,
            show_copy_button=True,
        )
        with gr.Row():
            msg = gr.Textbox(
                show_label=False,
                placeholder="Type your message here...",
                container=False,
            )
            submit = gr.Button("Send", variant="primary")
            clear = gr.Button("New Chat")

        def respond(message, chat_history):
            # Ask the bot, append the (user, bot) pair, and clear the textbox.
            bot_message = chatbot.chat(message, chat_history)
            chat_history.append((message, bot_message))
            return "", chat_history

        submit.click(respond, [msg, chatbot_interface], [msg, chatbot_interface])
        msg.submit(respond, [msg, chatbot_interface], [msg, chatbot_interface])
        # reset_chat() clears the bot's memory and returns [] to empty the widget.
        clear.click(lambda: chatbot.reset_chat(), None, chatbot_interface)

    return demo
# Build the app at import time rather than inside the guard — presumably so
# hosting runners (e.g. Hugging Face Spaces / `gradio app.py`) can find a
# module-level `demo` object; TODO confirm before moving it under the guard.
demo = create_demo()

if __name__ == "__main__":
    demo.launch()