import streamlit as st
import time
import os
import gc
import torch

from src.data_processing import load_huggingface_faq_data, load_faq_data, preprocess_faq
from src.embedding import FAQEmbedder
from src.llm_response import ResponseGenerator
from src.utils import time_function, format_memory_stats
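
# NOTE: The src.* helpers are separate modules of this Space and are not shown on this page.
# Judging only from how they are called below, their interfaces are assumed to look roughly
# like the sketch in this comment (signatures inferred from usage, not taken from the source):
#
#   class FAQEmbedder:
#       def create_embeddings(self, faqs: list[dict], batch_size: int = 32) -> None: ...
#       def retrieve_relevant_faqs(self, query: str) -> list[dict]: ...
#
#   class ResponseGenerator:
#       def __init__(self, model_name: str): ...
#       def generate_response(self, query: str, relevant_faqs: list[dict]) -> str: ...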

# Set page title and layout
st.set_page_config(
    page_title="E-Commerce FAQ Chatbot",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Memory optimization: Force garbage collection before starting
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()


def initialize_components(use_huggingface: bool = True, model_name: str = "mistralai/Mistral-7B-Instruct-v0.1"):
    """Initialize all components of the RAG system with memory optimization."""
    # Step 1: Load and preprocess FAQ data
    if use_huggingface:
        faqs = load_huggingface_faq_data("NebulaByte/E-Commerce_FAQs")
    else:
        data_path = os.path.join("data", "faq_data.csv")
        faqs = load_faq_data(data_path)
    processed_faqs = preprocess_faq(faqs)

    # Step 2: Initialize the embedder and create embeddings
    # (a smaller batch size keeps peak memory down)
    embedder = FAQEmbedder()
    embedder.create_embeddings(processed_faqs, batch_size=32)

    # Clear memory before loading the LLM
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Step 3: Initialize response generator
    response_generator = ResponseGenerator(model_name=model_name)

    return embedder, response_generator, len(processed_faqs)
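
# NOTE: Each processed FAQ is assumed to be a plain dict. From the way main() renders the
# retrieval results, a retrieved entry presumably looks roughly like this (values are
# illustrative; 'similarity' is assumed to be added by retrieve_relevant_faqs):
#
#   {"question": "How do I track my order?",
#    "answer": "You can track it from the 'My Orders' page...",
#    "category": "Shipping",
#    "similarity": 0.87}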


def main():
    st.title("E-Commerce Customer Support FAQ Chatbot")
    st.subheader("Ask about your orders, shipping, returns, or any other e-commerce topic")

    # Sidebar configuration
    st.sidebar.title("Configuration")
    use_huggingface = st.sidebar.checkbox("Use Hugging Face Dataset", value=True)

    # Model options - include smaller models by default
    model_options = {
        "Phi-2 (Recommended for 8GB GPU)": "microsoft/phi-2",
        "TinyLlama-1.1B (Smallest, fastest)": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "Mistral-7B (Requires 4-bit quantization)": "mistralai/Mistral-7B-Instruct-v0.1"
    }

    # Default to Phi-2 for 8-11GB GPUs
    selected_model = st.sidebar.selectbox("Select LLM Model", list(model_options.keys()), index=0)
    model_name = model_options[selected_model]

    # Memory usage monitoring
    if st.sidebar.checkbox("Show Memory Usage", value=True):
        st.sidebar.subheader("Memory Usage")
        memory_stats = format_memory_stats()
        for key, value in memory_stats.items():
            st.sidebar.text(f"{key}: {value}")

    # Initialize session state for chat history if it doesn't exist
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []

    # Initialize RAG components (only once)
    if "system_initialized" not in st.session_state or st.sidebar.button("Reload System"):
        with st.spinner("Initializing system components... This may take a few minutes."):
            st.session_state.embedder, st.session_state.response_generator, num_faqs = initialize_components(
                use_huggingface=use_huggingface,
                model_name=model_name
            )
            st.session_state.system_initialized = True
        st.sidebar.success(f"System initialized with {num_faqs} FAQs!")

    # Chat interface
    col1, col2 = st.columns([2, 1])

    with col1:
        # Display chat history
        st.subheader("Conversation")
        chat_container = st.container(height=400)
        with chat_container:
            for i, message in enumerate(st.session_state.chat_history):
                if message["role"] == "user":
                    st.markdown(f"**You**: {message['content']}")
                else:
                    st.markdown(f"**Bot**: {message['content']}")
                if i < len(st.session_state.chat_history) - 1:
                    st.markdown("---")

        # Chat input
        with st.form(key="chat_form"):
            user_query = st.text_input("Type your question:", key="user_input",
                                       placeholder="e.g., How do I track my order?")
            submit_button = st.form_submit_button("Ask")

    with col2:
        if st.session_state.get("system_initialized", False):
            # Show FAQ metadata and information
            st.subheader("Retrieved Information")
            info_container = st.container(height=500)
            with info_container:
                if "current_faqs" in st.session_state:
                    for i, faq in enumerate(st.session_state.get("current_faqs", [])):
                        st.markdown(f"**Relevant FAQ #{i+1}**")
                        st.markdown(f"**Q**: {faq['question']}")
                        # Limit answer length to keep the panel compact
                        st.markdown(f"**A**: {faq['answer'][:150]}..." if len(faq['answer']) > 150 else f"**A**: {faq['answer']}")
                        st.markdown(f"*Similarity Score*: {faq['similarity']:.2f}")
                        if 'category' in faq and faq['category']:
                            st.markdown(f"*Category*: {faq['category']}")
                        st.markdown("---")
                else:
                    st.markdown("Ask a question to see relevant FAQs here.")

    # Performance metrics in the sidebar
    if "retrieval_time" in st.session_state and "generation_time" in st.session_state:
        st.sidebar.subheader("Performance Metrics")
        st.sidebar.markdown(f"Retrieval time: {st.session_state.retrieval_time:.2f} seconds")
        st.sidebar.markdown(f"Response generation: {st.session_state.generation_time:.2f} seconds")
        st.sidebar.markdown(f"Total time: {st.session_state.retrieval_time + st.session_state.generation_time:.2f} seconds")

    # Process user query
    if submit_button and user_query:
        # Add user query to chat history
        st.session_state.chat_history.append({"role": "user", "content": user_query})

        # Process query
        with st.spinner("Thinking..."):
            # Free memory before processing
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Step 1: Retrieve relevant FAQs
            start_time = time.time()
            relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(user_query)
            retrieval_time = time.time() - start_time

            # Step 2: Generate response
            start_time = time.time()
            response = st.session_state.response_generator.generate_response(user_query, relevant_faqs)
            generation_time = time.time() - start_time

            # Store metrics and retrieved FAQs
            st.session_state.retrieval_time = retrieval_time
            st.session_state.generation_time = generation_time
            st.session_state.current_faqs = relevant_faqs

            # Step 3: Add response to chat history
            st.session_state.chat_history.append({"role": "assistant", "content": response})

            # Free memory after processing
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Rerun to display the updated chat history
        # (st.experimental_rerun was renamed st.rerun in newer Streamlit releases)
        st.experimental_rerun()

    # Add sample questions at the bottom
    st.subheader("Sample Questions")
    sample_questions = [
        "How do I track my order?",
        "What should I do if my delivery is delayed?",
        "How do I return a product?",
        "Can I cancel my order after placing it?",
        "How quickly will my order be delivered?",
        "Why can't I track my order yet?"
    ]

    # Use two columns instead of three to reduce memory usage
    cols = st.columns(2)
    for i, question in enumerate(sample_questions):
        col_idx = i % 2
        if cols[col_idx].button(question, key=f"sample_{i}"):
            # Treat the sample question as a submitted query. Note: the "user_input" key is
            # bound to the text_input widget rendered above, so it cannot be reassigned here;
            # the question is appended to the chat history and processed directly instead.
            st.session_state.chat_history.append({"role": "user", "content": question})

            # Process query (same steps as above)
            with st.spinner("Thinking..."):
                # Free memory before processing
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # Step 1: Retrieve relevant FAQs
                start_time = time.time()
                relevant_faqs = st.session_state.embedder.retrieve_relevant_faqs(question)
                retrieval_time = time.time() - start_time

                # Step 2: Generate response
                start_time = time.time()
                response = st.session_state.response_generator.generate_response(question, relevant_faqs)
                generation_time = time.time() - start_time

                # Store metrics and retrieved FAQs
                st.session_state.retrieval_time = retrieval_time
                st.session_state.generation_time = generation_time
                st.session_state.current_faqs = relevant_faqs

                # Step 3: Add response to chat history
                st.session_state.chat_history.append({"role": "assistant", "content": response})

                # Free memory after processing
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            # Rerun to display the updated chat history
            st.experimental_rerun()


if __name__ == "__main__":
    main()