# CXRAG / app.py
# ghostai1's picture
# Create app.py
# eda95b4 verified
# raw
# history blame
# 4.92 kB
import gradio as gr
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
import matplotlib.pyplot as plt
import seaborn as sns
import time
import os
# Sample FAQs, kept inline so the script needs no external data file.
# Each row is a (question, answer) pair. The last two rows are intentionally
# dirty -- one all-null row and one repeated question -- so that the cleanup
# step has something to remove.
faq_data = pd.DataFrame(
    [
        (
            'How do I reset my password?',
            'Go to the login page, click "Forgot Password," and follow the email instructions.',
        ),
        (
            'What are your pricing plans?',
            'We offer Basic ($10/month), Pro ($50/month), and Enterprise (custom).',
        ),
        (
            'How do I contact support?',
            'Email [email protected] or call +1-800-123-4567.',
        ),
        (None, None),  # Junk data (null)
        ('How do I reset my password?', 'Duplicate answer.'),  # Duplicate
    ],
    columns=['question', 'answer'],
)
# Data cleanup function
def clean_faqs(df, min_answer_len=20):
    """Return ``df`` with junk FAQ rows filtered out.

    Three filters are applied in order:

    1. rows containing a null in any column are dropped;
    2. rows whose ``question`` repeats an earlier row are dropped (the first
       occurrence wins);
    3. rows whose ``answer`` is not strictly longer than ``min_answer_len``
       characters are dropped.

    Args:
        df: DataFrame with string columns ``question`` and ``answer``.
        min_answer_len: Answers must have more than this many characters to
            be kept. Defaults to 20, the previously hard-coded threshold.

    Returns:
        A new, filtered DataFrame. Index labels of the surviving rows are
        left as they were in ``df`` (no re-indexing is done).
    """
    df = df.dropna()  # Remove nulls
    # keep='first' matches the original `~duplicated()` mask.
    df = df.drop_duplicates(subset='question', keep='first')  # Remove duplicates
    # Null removal must come first: `.str.len()` yields NaN for nulls.
    df = df[df['answer'].str.len() > min_answer_len]  # Filter short answers
    return df
# Preprocess FAQs
# Rebind `faq_data` to its cleaned version; the embeddings, the index and all
# later lookups are built from these cleaned rows only.
faq_data = clean_faqs(faq_data)
# Initialize RAG components
# Sentence-transformers model used to embed both the FAQ questions (below)
# and incoming user queries (in rag_process).
embedder = SentenceTransformer('all-MiniLM-L6-v2')
# Embed every cleaned FAQ question once, at import time. Row i of `embeddings`
# lines up with positional row i of `faq_data`, which is what lets
# rag_process map search results back with `faq_data.iloc[...]`.
embeddings = embedder.encode(faq_data['question'].tolist(), show_progress_bar=False)
# FAISS IndexFlatL2 sized to the embedding width (second axis of `embeddings`).
index = faiss.IndexFlatL2(embeddings.shape[1])
# Vectors are cast to float32 before being handed to the index.
index.add(embeddings.astype(np.float32))
# RAG process
def rag_process(query, k=2):
    """Answer ``query`` from the closest FAQ entries and time each stage.

    The query is embedded with the module-level ``embedder``, the ``k``
    nearest FAQ questions are looked up in the module-level FAISS ``index``,
    and the answer of the best match is returned as the response.

    Args:
        query: User question. Non-string values and strings with fewer than
            5 characters after stripping whitespace are rejected.
        k: Number of nearest FAQs to request from the index.

    Returns:
        A ``(response, retrieved_faqs, metrics)`` tuple:

        * ``response`` -- answer text of the top hit, a fallback message when
          nothing was retrieved, or an "Invalid query" message.
        * ``retrieved_faqs`` -- list of ``{'question', 'answer'}`` dicts,
          best match first (empty for rejected queries).
        * ``metrics`` -- dict with ``embed_time``, ``retrieval_time`` and
          ``generation_time`` in milliseconds plus a simulated ``accuracy``
          percentage; an empty dict for rejected queries.
    """
    # Validate the *stripped* text so whitespace padding cannot satisfy the
    # minimum-length rule, and so None / non-string input is rejected instead
    # of raising AttributeError on `.strip()`.
    cleaned_query = query.strip() if isinstance(query, str) else ""
    if len(cleaned_query) < 5:
        return "Invalid query. Please enter a valid question.", [], {}

    start_time = time.perf_counter()
    # Embed query
    query_embedding = embedder.encode([cleaned_query], show_progress_bar=False)
    embed_time = time.perf_counter() - start_time

    # Retrieve FAQs
    start_time = time.perf_counter()
    _distances, indices = index.search(query_embedding.astype(np.float32), k)
    # FAISS pads the result with -1 when fewer than k neighbours exist. Fed
    # to `iloc`, -1 would silently select the LAST FAQ row, so drop it.
    hit_rows = [int(row) for row in indices[0] if row >= 0]
    retrieved_faqs = faq_data.iloc[hit_rows][['question', 'answer']].to_dict('records')
    retrieval_time = time.perf_counter() - start_time

    # Generate response (rule-based for free tier)
    start_time = time.perf_counter()
    if retrieved_faqs:
        response = retrieved_faqs[0]['answer']
    else:
        response = "Sorry, I couldn't find an answer."
    generation_time = time.perf_counter() - start_time

    # Metrics
    metrics = {
        'embed_time': embed_time * 1000,  # ms
        'retrieval_time': retrieval_time * 1000,
        'generation_time': generation_time * 1000,
        'accuracy': 95.0 if retrieved_faqs else 0.0,  # Simulated
    }
    return response, retrieved_faqs, metrics
# Plot RAG pipeline
def plot_metrics(metrics, output_path='rag_plot.png'):
    """Render per-stage latency and accuracy to a PNG and return its path.

    Latency is drawn as bars on the left y-axis and accuracy as a line on a
    twin right y-axis, for the three stages Embedding, Retrieval, Generation.

    Args:
        metrics: Dict as produced by ``rag_process`` with keys
            ``embed_time``, ``retrieval_time``, ``generation_time`` (ms) and
            ``accuracy`` (%). Missing keys -- including the empty dict that
            ``rag_process`` returns for rejected queries -- are plotted as 0.
        output_path: Where the image is written. Defaults to the previously
            hard-coded ``'rag_plot.png'``.

    Returns:
        ``output_path``, after the figure has been saved there.
    """
    metrics = metrics or {}
    accuracy = metrics.get('accuracy', 0.0)
    data = pd.DataFrame({
        'Stage': ['Embedding', 'Retrieval', 'Generation'],
        'Latency (ms)': [
            metrics.get('embed_time', 0.0),
            metrics.get('retrieval_time', 0.0),
            metrics.get('generation_time', 0.0),
        ],
        # The embedding stage is shown at a fixed 100%.
        'Accuracy (%)': [100, accuracy, accuracy],
    })

    # Style must be set before the axes are created for it to take effect.
    sns.set_style("whitegrid")
    sns.set_palette("muted")
    fig, ax1 = plt.subplots(figsize=(8, 5))
    try:
        # Pass axes explicitly instead of relying on pyplot's "current axes".
        sns.barplot(x='Stage', y='Latency (ms)', data=data, color='skyblue', ax=ax1)
        ax1.set_ylabel('Latency (ms)', color='blue')
        ax1.tick_params(axis='y', labelcolor='blue')

        ax2 = ax1.twinx()
        sns.lineplot(x='Stage', y='Accuracy (%)', data=data, marker='o', color='red', ax=ax2)
        ax2.set_ylabel('Accuracy (%)', color='red')
        ax2.tick_params(axis='y', labelcolor='red')

        ax1.set_title('RAG Pipeline: Latency and Accuracy')
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    return output_path
# Gradio interface
def chat_interface(query):
    """Run the RAG pipeline for ``query`` and format the results for the UI.

    Args:
        query: Text entered by the user.

    Returns:
        A 4-tuple matching the Gradio outputs, in order: the bot response,
        the retrieved FAQs rendered as ``Q:``/``A:`` text, a one-line data
        cleanup summary, and the path of the metrics plot (``None`` when the
        query was rejected and there is nothing to plot).
    """
    response, retrieved_faqs, metrics = rag_process(query)
    # rag_process returns an empty metrics dict for rejected queries; plotting
    # it would raise KeyError and hide the "Invalid query" message, so show
    # no image in that case.
    plot_path = plot_metrics(metrics) if metrics else None
    faq_text = "\n".join(
        f"Q: {faq['question']}\nA: {faq['answer']}" for faq in retrieved_faqs
    )
    # Row count of the inline sample FAQ table before cleaning.
    raw_faq_count = 5
    cleaned_count = len(faq_data)
    cleanup_stats = (
        f"Cleaned FAQs: {cleaned_count} "
        f"(removed {raw_faq_count - cleaned_count} junk entries)"
    )
    return response, faq_text, cleanup_stats, plot_path
# Dark theme CSS
custom_css = """
body { background-color: #2a2a2a; color: #e0e0e0; }
.gr-box { background-color: #3a3a3a; border: 1px solid #4a4a4a; }
.gr-button { background-color: #1e90ff; color: white; }
.gr-button:hover { background-color: #1c86ee; }
"""

# Build the UI: a query box with a submit button, and four outputs that are
# filled by chat_interface (response, retrieved FAQs, cleanup stats, plot).
# NOTE(review): the nesting below was reconstructed from a copy that had lost
# its indentation; the Row is assumed to hold only the query box and the
# button -- confirm against the intended layout.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# Crescendo CX Bot Demo")
    gr.Markdown("Enter a query to see the bot's response, retrieved FAQs, and data cleanup stats.")
    with gr.Row():
        query_input = gr.Textbox(label="Your Query", placeholder="e.g., How do I reset my password?")
        submit_btn = gr.Button("Submit")
    response_output = gr.Textbox(label="Bot Response")
    faq_output = gr.Textbox(label="Retrieved FAQs")
    cleanup_output = gr.Textbox(label="Data Cleanup Stats")
    plot_output = gr.Image(label="RAG Pipeline Metrics")
    # Output order must match the tuple returned by chat_interface.
    submit_btn.click(
        fn=chat_interface,
        inputs=query_input,
        outputs=[response_output, faq_output, cleanup_output, plot_output],
    )

# Start the web server when the script is executed.
demo.launch()