# 23RAG7 / app.py
# Uploaded by cb1716pics ("Upload 2 files", commit ced5431, verified) - 6.98 kB
# import streamlit as st
# from generator import generate_response_from_document
# from retrieval import retrieve_documents_hybrid
# from evaluation import calculate_metrics
# #from data_processing import load_data_from_faiss
# import time
# # Page Title
# st.title("RAG7 - Real World RAG System")
# # global retrieved_documents
# # retrieved_documents = []
# # global response
# # response = ""
# # global time_taken_for_response
# # time_taken_for_response = 'N/A'
# # @st.cache_data
# # def load_data():
# # load_data_from_faiss()
# # data_status = load_data()
# # Question Section
# st.subheader("Hi, What do you want to know today?")
# question = st.text_area("Enter your question:", placeholder="Type your question here...", height=100)
# # # Submit Button
# # if st.button("Submit"):
# # start_time = time.time()
# # retrieved_documents = retrieve_documents_hybrid(question, 10)
# # response = generate_response_from_document(question, retrieved_documents)
# # end_time = time.time()
# # time_taken_for_response = end_time-start_time
# # else:
# # response = ""
# # # Response Section
# # st.subheader("Response")
# # st.text_area("Generated Response:", value=response, height=150, disabled=True)
# # # Metrics Section
# # st.subheader("Metrics")
# # col1, col2 = st.columns([1, 3]) # Creating two columns for button and metrics display
# # with col1:
# # if st.button("Calculate Metrics"):
# # metrics = calculate_metrics(question, response, retrieved_documents, time_taken_for_response)
# # else:
# # metrics = ""
# # with col2:
# # st.text_area("Metrics:", value=metrics, height=100, disabled=True)
# if "retrieved_documents" not in st.session_state:
# st.session_state.retrieved_documents = []
# if "response" not in st.session_state:
# st.session_state.response = ""
# if "time_taken_for_response" not in st.session_state:
# st.session_state.time_taken_for_response = "N/A"
# # Submit Button
# if st.button("Submit"):
# start_time = time.time()
# st.session_state.retrieved_documents = retrieve_documents_hybrid(question, 10)
# st.session_state.response = generate_response_from_document(question, st.session_state.retrieved_documents)
# end_time = time.time()
# st.session_state.time_taken_for_response = end_time - start_time
# # Display stored response
# st.subheader("Response")
# st.text_area("Generated Response:", value=st.session_state.response, height=150, disabled=True)
# col1, col2 = st.columns([1, 3]) # Creating two columns for button and metrics display
# # Calculate Metrics Button
# with col1:
# if st.button("Calculate Metrics"):
# metrics = calculate_metrics(question, st.session_state.response, st.session_state.retrieved_documents, st.session_state.time_taken_for_response)
# else:
# metrics = {}
# with col2:
# #st.text_area("Metrics:", value=metrics, height=100, disabled=True)
# st.json(metrics)
import logging
import time

import pandas as pd
import plotly.express as px
import streamlit as st
from datasets import load_dataset, Dataset, DatasetDict

from evaluation import calculate_metrics
from generator import generate_response_from_document
from retrieval import retrieve_documents_hybrid
# Hugging Face Dataset Details
HF_DATASET_REPO = "cb1716pics/23RAG7_recent_questions"  # Hugging Face repo


# Load Dataset from Hugging Face
@st.cache_resource
def load_hf_dataset():
    """Load the recent-questions dataset from the Hugging Face Hub.

    The result is cached with ``st.cache_resource``.

    Returns:
        The object returned by ``load_dataset(HF_DATASET_REPO)``. If loading
        fails for any reason, an empty ``DatasetDict`` with a single
        ``"recent"`` split holding empty ``question``, ``response`` and
        ``metrics`` columns is returned instead, so the app can still start.
    """
    try:
        return load_dataset(HF_DATASET_REPO)
    except Exception:
        # Catch Exception rather than using a bare ``except`` so that
        # KeyboardInterrupt, SystemExit and other BaseException-derived
        # control-flow signals are not swallowed. Log the cause so a broken
        # repo or missing credentials do not go unnoticed.
        logging.getLogger(__name__).warning(
            "Could not load dataset %s; starting with an empty history.",
            HF_DATASET_REPO,
            exc_info=True,
        )
        return DatasetDict(
            {"recent": Dataset.from_dict({"question": [], "response": [], "metrics": []})}
        )


dataset = load_hf_dataset()
# Function to Save Data to Hugging Face Dataset
def save_to_hf_dataset(question, response, metrics):
    """Append one question/response/metrics record and push it to the Hub.

    The record is added to the ``"recent"`` split of the module-level
    ``dataset``; only the last 10 entries of every column are kept. The
    updated dataset is then pushed to ``HF_DATASET_REPO``.

    Args:
        question: The question text to store.
        response: The generated response to store.
        metrics: The metrics computed for this question/response pair.
    """
    global dataset

    latest_row = {"question": question, "response": response, "metrics": metrics}

    # Work on plain column lists: existing values first, new value last.
    columns = dataset["recent"].to_dict()
    for name, value in latest_row.items():
        columns[name] = columns.get(name, []) + [value]

    # Keep only the last 10 entries of each column.
    trimmed = {name: values[-10:] for name, values in columns.items()}

    # Rebuild the split and publish it.
    dataset["recent"] = Dataset.from_dict(trimmed)
    dataset.push_to_hub(HF_DATASET_REPO)
# Streamlit UI
st.title("🔍 RAG7 - Real World RAG System")

# Sidebar - Recent Questions
st.sidebar.header("📌 Recent Questions")
if len(dataset["recent"]) > 0:
    for q in dataset["recent"]["question"][-10:]:
        st.sidebar.write(f"🔹 {q}")

# Sidebar - Analytics with Graph
st.sidebar.header("📊 Analytics Overview")
if len(dataset["recent"]) > 0:
    # Extract recent metrics for visualization
    metrics_data = dataset["recent"]["metrics"][-10:]
    metrics_keys = ["context_relevance", "context_utilization", "completeness", "adherence"]

    # Prepare a dictionary for graphing. A stored entry may be missing one of
    # the plotted metrics (or not be a dict at all); use None for those
    # instead of raising so one bad record cannot break the whole sidebar.
    graph_data = {
        key: [m.get(key) if isinstance(m, dict) else None for m in metrics_data]
        for key in metrics_keys
    }
    graph_data["Question #"] = list(range(1, len(metrics_data) + 1))

    # Convert to DataFrame for Plotly. Coerce the metric columns to numbers
    # (missing / non-numeric values become NaN) so every plotted column has a
    # numeric dtype.
    df = pd.DataFrame(graph_data)
    df[metrics_keys] = df[metrics_keys].apply(pd.to_numeric, errors="coerce")

    # Plot Metrics Over Time
    fig = px.line(
        df,
        x="Question #",
        y=metrics_keys,
        labels={"value": "Score", "variable": "Metric"},
        title="📈 Model Performance Over Recent Questions",
    )
    st.sidebar.plotly_chart(fig, use_container_width=True)

# Evaluate Button
if st.sidebar.button("⚡ Evaluate RAG Model"):
    st.sidebar.success("✅ Model Evaluation Triggered!")
# Main Section - User Input
st.subheader("💬 Ask a Question")
question = st.text_area("Enter your question:", placeholder="Type your question here...", height=100)

# Submit Button
if st.button("🚀 Submit"):
    if not question or not question.strip():
        # Nothing to retrieve or generate for a blank question.
        st.warning("Please enter a question before submitting.")
    else:
        # Time retrieval + generation together; the elapsed seconds are
        # passed on to calculate_metrics.
        start_time = time.time()
        retrieved_documents = retrieve_documents_hybrid(question, 10)
        response = generate_response_from_document(question, retrieved_documents)
        end_time = time.time()
        time_taken_for_response = end_time - start_time

        # Calculate Metrics
        metrics = calculate_metrics(question, response, retrieved_documents, time_taken_for_response)

        # Display Response
        st.subheader("💡 Response")
        st.text_area("Generated Response:", value=response, height=150, disabled=True)

        # Display Metrics with Bar Chart
        st.subheader("📊 Metrics")
        st.json(metrics)

        # Plot Bar Chart for Metrics
        metric_df = pd.DataFrame({"Metric": list(metrics.keys()), "Score": list(metrics.values())})
        fig2 = px.bar(metric_df, x="Metric", y="Score", title="📊 Current Query Metrics")
        st.plotly_chart(fig2, use_container_width=True)

        # Save Data. Done last and guarded so that a failed push to the Hub
        # (no network, missing token, ...) cannot prevent the answer above
        # from being shown.
        try:
            save_to_hf_dataset(question, response, metrics)
        except Exception as exc:
            st.warning(f"Could not save this question to the recent history: {exc}")