# faq-rag-chatbot/app.py
import streamlit as st
import time
import os
import gc
import torch
from src.data_processing import load_huggingface_faq_data, load_faq_data, preprocess_faq
from src.embedding import FAQEmbedder
from src.llm_response import ResponseGenerator
from src.utils import time_function, format_memory_stats
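# The helpers imported from src/ are part of this project. As a rough, assumed
# sketch only (the real implementation in src/utils.py may differ),
# `time_function` is taken to be a simple timing decorator:
#
#     def time_function(func):
#         def wrapper(*args, **kwargs):
#             start = time.time()
#             result = func(*args, **kwargs)
#             print(f"{func.__name__} took {time.time() - start:.2f}s")
#             return result
#         return wrapper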
# Set page title and layout
st.set_page_config(
page_title="E-Commerce FAQ Chatbot",
layout="wide",
initial_sidebar_state="expanded"
)
# Memory optimization: Force garbage collection before starting
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
@time_function
def initialize_components(use_huggingface: bool = True, model_name: str = "mistralai/Mistral-7B-Instruct-v0.1"):
"""Initialize all components of the RAG system with memory optimization"""
# Step 1: Load and preprocess FAQ data
if use_huggingface:
faqs = load_huggingface_faq_data("NebulaByte/E-Commerce_FAQs")
else:
data_path = os.path.join("data", "faq_data.csv")
faqs = load_faq_data(data_path)
processed_faqs = preprocess_faq(faqs)
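    # preprocess_faq (src/data_processing.py) is assumed to clean and normalize
    # the raw records into the question/answer dicts the embedder expects.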
# Step 2: Initialize and create embeddings
# Use smaller batch size for memory efficiency
embedder = FAQEmbedder()
embedder.create_embeddings(processed_faqs, batch_size=32)
# Clear memory before loading the LLM
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Step 3: Initialize response generator
response_generator = ResponseGenerator(model_name=model_name)
return embedder, response_generator, len(processed_faqs)
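# Interfaces assumed by this file (as exercised below):
#   - FAQEmbedder.create_embeddings(faqs, batch_size=...) embeds the processed FAQs.
#   - FAQEmbedder.retrieve_relevant_faqs(query) returns a list of dicts with at least
#     "question", "answer" and "similarity" keys (plus an optional "category"),
#     which the UI renders in the right-hand column.
#   - ResponseGenerator.generate_response(query, relevant_faqs) returns the answer
#     string appended to the chat history.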
def main():
st.title("E-Commerce Customer Support FAQ Chatbot")
st.subheader("Ask any question about your orders, shipping, returns, or any other e-commerce related queries")
# Sidebar configuration
st.sidebar.title("Configuration")
use_huggingface = st.sidebar.checkbox("Use Hugging Face Dataset", value=True)
# Model options - include smaller models by default
model_options = {
"Phi-2 (Recommended for 8GB GPU)": "microsoft/phi-2",
"TinyLlama-1.1B (Smallest, fastest)": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"Mistral-7B (Requires 4-bit quantization)": "mistralai/Mistral-7B-Instruct-v0.1"
}
# Default to Phi-2 for 8-11GB GPU
selected_model = st.sidebar.selectbox("Select LLM Model", list(model_options.keys()), index=0)
model_name = model_options[selected_model]
# Memory usage monitoring
if st.sidebar.checkbox("Show Memory Usage", value=True):
st.sidebar.subheader("Memory Usage")
memory_stats = format_memory_stats()
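        # format_memory_stats() (src/utils.py) is expected to return a dict of
        # human-readable label -> value strings, e.g. roughly
        # {"RAM used": "3.2 GB", "GPU allocated": "1.1 GB"}; the exact keys
        # depend on the helper's implementation.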
for key, value in memory_stats.items():
st.sidebar.text(f"{key}: {value}")
# Initialize session state for chat history if it doesn't exist
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
# Initialize RAG components (only once)
if "system_initialized" not in st.session_state or st.sidebar.button("Reload System"):
with st.spinner("Initializing system components... This may take a few minutes."):
st.session_state.embedder, st.session_state.response_generator, num_faqs = initialize_components(
use_huggingface=use_huggingface,
model_name=model_name
)
st.session_state.system_initialized = True
st.sidebar.success(f"System initialized with {num_faqs} FAQs!")
# Chat interface
col1, col2 = st.columns([2, 1])
with col1:
# Display chat history
st.subheader("Conversation")
chat_container = st.container(height=400)
with chat_container:
for i, message in enumerate(st.session_state.chat_history):
if message["role"] == "user":
st.markdown(f"**You**: {message['content']}")
else:
st.markdown(f"**Bot**: {message['content']}")
if i < len(st.session_state.chat_history) - 1:
st.markdown("---")
# Chat input
with st.form(key="chat_form"):
user_query = st.text_input("Type your question:", key="user_input",
placeholder="e.g., How do I track my order?")
submit_button = st.form_submit_button("Ask")
with col2:
if st.session_state.get("system_initialized", False):
# Show FAQ metadata and information
st.subheader("Retrieved Information")
info_container = st.container(height=500)
with info_container:
if "current_faqs" in st.session_state:
for i, faq in enumerate(st.session_state.get("current_faqs", [])):
st.markdown(f"**Relevant FAQ #{i+1}**")
st.markdown(f"**Q**: {faq['question']}")
# Limit answer length to save UI memory
st.markdown(f"**A**: {faq['answer'][:150]}..." if len(faq['answer']) > 150 else f"**A**: {faq['answer']}")
st.markdown(f"*Similarity Score*: {faq['similarity']:.2f}")
if 'category' in faq and faq['category']:
st.markdown(f"*Category*: {faq['category']}")
st.markdown("---")
else:
st.markdown("Ask a question to see relevant FAQs here.")
# Performance metrics in the sidebar
if "retrieval_time" in st.session_state and "generation_time" in st.session_state:
st.sidebar.subheader("Performance Metrics")
st.sidebar.markdown(f"Retrieval time: {st.session_state.retrieval_time:.2f} seconds")
st.sidebar.markdown(f"Response generation: {st.session_state.generation_time:.2f} seconds")
st.sidebar.markdown(f"Total time: {st.session_state.retrieval_time + st.session_state.generation_time:.2f} seconds")
# Process user query
if submit_button and user_query:
# Add user query to chat history
st.session_state.chat_history.append({"role": "user", "content": user_query})
# Process query
with st.spinner("Thinking..."):
# Free memory before processing
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Step 1: Retrieve relevant FAQs
start_time = time.time()
relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(user_query)
retrieval_time = time.time() - start_time
# Step 2: Generate response
start_time = time.time()
response = st.session_state.response_generator.generate_response(user_query, relevant_faqs)
generation_time = time.time() - start_time
# Store metrics and retrieved FAQs
st.session_state.retrieval_time = retrieval_time
st.session_state.generation_time = generation_time
st.session_state.current_faqs = relevant_faqs
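            # relevant_faqs is kept in session_state so the "Retrieved Information"
            # panel can show the supporting FAQs for the latest answer after the
            # rerun below.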
# Step 3: Add response to chat history
st.session_state.chat_history.append({"role": "assistant", "content": response})
# Free memory after processing
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Rerun to display the updated chat history
            st.rerun()  # st.experimental_rerun() is deprecated in newer Streamlit
# Add sample questions at the bottom
st.subheader("Sample Questions")
sample_questions = [
"How do I track my order?",
"What should I do if my delivery is delayed?",
"How do I return a product?",
"Can I cancel my order after placing it?",
"How quickly will my order be delivered?",
"Why can't I track my order yet?"
]
# Use two columns instead of three to reduce memory usage
cols = st.columns(2)
for i, question in enumerate(sample_questions):
col_idx = i % 2
if cols[col_idx].button(question, key=f"sample_{i}"):
            # Don't write to st.session_state.user_input here: assigning to a
            # widget key after its widget has been instantiated raises a
            # StreamlitAPIException, so the sample question is processed directly.
# Simulate form submission
st.session_state.chat_history.append({"role": "user", "content": question})
# Process query (similar to above)
with st.spinner("Thinking..."):
# Free memory before processing
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Step 1: Retrieve relevant FAQs
start_time = time.time()
relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(question)
retrieval_time = time.time() - start_time
# Step 2: Generate response
start_time = time.time()
response = st.session_state.response_generator.generate_response(question, relevant_faqs)
generation_time = time.time() - start_time
# Store metrics and retrieved FAQs
st.session_state.retrieval_time = retrieval_time
st.session_state.generation_time = generation_time
st.session_state.current_faqs = relevant_faqs
# Step 3: Add response to chat history
st.session_state.chat_history.append({"role": "assistant", "content": response})
# Free memory after processing
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Rerun to display the updated chat history
                st.rerun()
if __name__ == "__main__":
main()