import gradio as gr
import numpy as np
import wikipedia
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
from sentence_transformers import SentenceTransformer
import faiss
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import time
import pandas as pd
import warnings

warnings.filterwarnings("ignore")
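
# Rough dependency list for this Space (a sketch; exact packages and versions are assumptions,
# not pinned in this file): gradio, numpy, wikipedia, transformers, torch,
# sentence-transformers, faiss-cpu, plotly, pandas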

# Global variables to store models and data
embedding_model = None
qa_pipeline = None
chunks = None
embeddings = None
index = None
document = None


def load_models():
    """Load and cache the ML models"""
    global embedding_model, qa_pipeline
    if embedding_model is None:
        print("🤖 Loading embedding model...")
        embedding_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
        print("🤖 Loading QA model...")
        qa_tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
        qa_model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")
        qa_pipeline = pipeline("question-answering", model=qa_model, tokenizer=qa_tokenizer)
        print("✅ Models loaded successfully!")
    return "✅ Models are ready!"


def get_wikipedia_content(topic):
    """Fetch Wikipedia content"""
    try:
        page = wikipedia.page(topic)
        return page.content, f"✅ Successfully fetched '{topic}' article"
    except wikipedia.exceptions.PageError:
        return None, f"❌ Page '{topic}' not found. Please try a different topic."
    except wikipedia.exceptions.DisambiguationError as e:
        return None, f"⚠️ Ambiguous topic. Try one of these: {', '.join(e.options[:5])}"


def split_text(text, chunk_size=256, chunk_overlap=20):
    """Split text into overlapping chunks"""
    tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
    # Split into sentences first
    sentences = text.split('. ')
    chunks = []
    current_chunk = ""
    for sentence in sentences:
        test_chunk = current_chunk + ". " + sentence if current_chunk else sentence
        test_tokens = tokenizer.tokenize(test_chunk)
        if len(test_tokens) > chunk_size:
            if current_chunk:
                chunks.append(current_chunk.strip())
                # Add overlap: carry the tail of the previous chunk into the next one
                if chunk_overlap > 0 and chunks:
                    overlap_tokens = tokenizer.tokenize(current_chunk)
                    if len(overlap_tokens) > chunk_overlap:
                        overlap_start = len(overlap_tokens) - chunk_overlap
                        overlap_text = tokenizer.convert_tokens_to_string(overlap_tokens[overlap_start:])
                        current_chunk = overlap_text + ". " + sentence
                    else:
                        current_chunk = sentence
                else:
                    current_chunk = sentence
            else:
                current_chunk = sentence
        else:
            current_chunk = test_chunk
    if current_chunk.strip():
        chunks.append(current_chunk.strip())
    return chunks
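
# Illustrative usage only (hypothetical input, not part of the app flow):
#   split_text("First sentence. Second sentence. Third sentence.", chunk_size=64, chunk_overlap=10)
# returns sentence-aligned chunks of roughly 64 tokens each, with about the last 10 tokens
# of each chunk repeated at the start of the next one.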


def process_article(topic, chunk_size, chunk_overlap):
    """Process Wikipedia article into chunks and embeddings"""
    global chunks, embeddings, index, document
    if not topic.strip():
        return "⚠️ Please enter a topic first!", None, ""
    # Load models first
    load_models()
    # Fetch content
    document, message = get_wikipedia_content(topic)
    if document is None:
        return message, None, ""
    # Process text
    chunks = split_text(document, int(chunk_size), int(chunk_overlap))
    # Create embeddings
    embeddings = embedding_model.encode(chunks)
    # Build FAISS index
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(np.array(embeddings))
    # Create summary stats
    chunk_lengths = [len(chunk.split()) for chunk in chunks]
    summary = f"""
📊 **Processing Summary:**
- **Total chunks**: {len(chunks)}
- **Embedding dimension**: {dimension}
- **Average chunk length**: {np.mean(chunk_lengths):.1f} words
- **Min/Max chunk length**: {min(chunk_lengths)}/{max(chunk_lengths)} words
- **Document length**: {len(document.split())} words

✅ Ready for questions!
"""
    return f"✅ Successfully processed '{topic}' into {len(chunks)} chunks!", create_chunk_visualization(), summary


def create_chunk_visualization():
    """Create chunk length distribution plot"""
    if chunks is None:
        return None
    chunk_lengths = [len(chunk.split()) for chunk in chunks]
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=("📊 Chunk Length Distribution", "📈 Statistical Summary"),
        specs=[[{"type": "bar"}, {"type": "box"}]]
    )
    # Histogram
    fig.add_trace(
        go.Histogram(x=chunk_lengths, nbinsx=15, name="Distribution",
                     marker_color="skyblue", opacity=0.7),
        row=1, col=1
    )
    # Box plot
    fig.add_trace(
        go.Box(y=chunk_lengths, name="Statistics",
               marker_color="lightgreen", boxmean=True),
        row=1, col=2
    )
    fig.update_layout(height=400, showlegend=False, title="📊 Chunk Analysis")
    return fig


def answer_question(question, k_retrieval):
    """Answer question using RAG pipeline"""
    global chunks, embeddings, index, qa_pipeline
    if chunks is None or index is None:
        return "⚠️ Please process an article first!", None, "", ""
    if not question.strip():
        return "⚠️ Please enter a question!", None, "", ""
    # Get query embedding
    query_embedding = embedding_model.encode([question])
    # Search (clamp k so FAISS never returns placeholder -1 indices when k > number of chunks)
    k = min(int(k_retrieval), len(chunks))
    distances, indices = index.search(np.array(query_embedding), k)
    retrieved_chunks = [chunks[i] for i in indices[0]]
    # Generate answer
    context = " ".join(retrieved_chunks)
    answer = qa_pipeline(question=question, context=context)
    # Format results
    confidence = answer['score']
    # Determine confidence level
    if confidence >= 0.8:
        confidence_emoji = "🟢"
        confidence_text = "Very High"
    elif confidence >= 0.6:
        confidence_emoji = "🔵"
        confidence_text = "High"
    elif confidence >= 0.4:
        confidence_emoji = "🟡"
        confidence_text = "Medium"
    else:
        confidence_emoji = "🔴"
        confidence_text = "Low"
    # Format answer
    formatted_answer = f"""
🤖 **Answer**: {answer['answer']}

{confidence_emoji} **Confidence**: {confidence:.1%} ({confidence_text})

📏 **Answer Length**: {len(answer['answer'])} characters

📚 **Chunks Used**: {len(retrieved_chunks)}
"""
    # Format retrieved chunks
    retrieved_text = "📚 **Retrieved Context Chunks:**\n\n"
    for i, chunk in enumerate(retrieved_chunks):
        similarity = 1 / (1 + distances[0][i])
        retrieved_text += f"**Chunk {i+1}** (Similarity: {similarity:.3f}):\n{chunk}\n\n---\n\n"
    # Create similarity visualization (convert L2 distances to a bounded similarity score)
    similarity_scores = 1 / (1 + distances[0])
    similarity_plot = create_similarity_plot(similarity_scores)
    return formatted_answer, similarity_plot, retrieved_text, create_confidence_gauge(confidence)


def create_similarity_plot(similarity_scores):
    """Create similarity scores bar chart"""
    # Gold/silver/bronze for the top three ranks, a neutral color for any further ranks
    colors = ['gold', 'silver', '#CD7F32'] + ['lightblue'] * max(0, len(similarity_scores) - 3)
    fig = go.Figure(data=[
        go.Bar(x=[f"Rank {i+1}" for i in range(len(similarity_scores))],
               y=similarity_scores,
               marker_color=colors[:len(similarity_scores)],
               text=[f'{score:.3f}' for score in similarity_scores],
               textposition='auto')
    ])
    fig.update_layout(
        title="🎯 Retrieved Chunks Similarity Scores",
        xaxis_title="Retrieved Chunk Rank",
        yaxis_title="Similarity Score",
        height=400
    )
    return fig


def create_confidence_gauge(confidence):
    """Create confidence gauge visualization"""
    fig = go.Figure(go.Indicator(
        mode="gauge+number+delta",
        value=confidence * 100,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': "🎯 Answer Confidence (%)"},
        delta={'reference': 80},
        gauge={
            'axis': {'range': [None, 100]},
            'bar': {'color': "darkblue"},
            'steps': [
                {'range': [0, 20], 'color': "red"},
                {'range': [20, 40], 'color': "orange"},
                {'range': [40, 60], 'color': "yellow"},
                {'range': [60, 80], 'color': "lightgreen"},
                {'range': [80, 100], 'color': "green"}
            ],
            'threshold': {
                'line': {'color': "black", 'width': 4},
                'thickness': 0.75,
                'value': 90
            }
        }
    ))
    fig.update_layout(height=400)
    return fig


def clear_data():
    """Clear all processed data"""
    global chunks, embeddings, index, document
    chunks = None
    embeddings = None
    index = None
    document = None
    return "🗑️ Data cleared! Ready for new article.", None, "", "", None, None, ""


# Create Gradio interface optimized for Hugging Face Spaces
def create_interface():
    """Create the main Gradio interface"""
    with gr.Blocks(
        title="🚀 RAG Pipeline For LLMs",
        theme=gr.themes.Soft(),
    ) as interface:
        # Header
        gr.Markdown("""
# 🚀 RAG Pipeline For LLMs 🚀
<div style="text-align: center; color: #666; margin-bottom: 2rem;">
An intelligent Q&A system powered by 🤗 Hugging Face, 📚 Wikipedia, and ⚡ FAISS vector search
</div>
        """)
        with gr.Tab("📄 Article Processing"):
            with gr.Row():
                with gr.Column(scale=2):
                    gr.Markdown("### 📝 Step 1: Configure & Process Article")
                    topic_input = gr.Textbox(
                        label="📚 Wikipedia Topic",
                        placeholder="e.g., Artificial Intelligence, Climate Change, Python Programming",
                        info="Enter any topic available on Wikipedia"
                    )
                    with gr.Row():
                        chunk_size = gr.Slider(
                            label="📏 Chunk Size (tokens)",
                            minimum=128,
                            maximum=512,
                            value=256,
                            step=32,
                            info="Larger chunks = more context, smaller chunks = more precision"
                        )
                        chunk_overlap = gr.Slider(
                            label="🔗 Chunk Overlap (tokens)",
                            minimum=10,
                            maximum=50,
                            value=20,
                            step=5,
                            info="Overlap helps maintain context between chunks"
                        )
                    process_btn = gr.Button("🚀 Fetch & Process Article", variant="primary", size="lg")
                    processing_status = gr.Textbox(
                        label="📋 Processing Status",
                        interactive=False
                    )
                with gr.Column(scale=1):
                    processing_summary = gr.Markdown("### 📊 Processing Summary\n*Process an article to see statistics*")
                    chunk_plot = gr.Plot(label="📊 Chunk Analysis Visualization")
        with gr.Tab("❓ Question Answering"):
            with gr.Row():
                with gr.Column(scale=2):
                    gr.Markdown("### 🎯 Step 2: Ask Your Question")
                    question_input = gr.Textbox(
                        label="❓ Your Question",
                        placeholder="e.g., What is the main concept? How does it work?",
                        info="Ask any question about the processed article"
                    )
                    k_retrieval = gr.Slider(
                        label="🔍 Number of Chunks to Retrieve",
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=1,
                        info="More chunks = broader context, fewer chunks = more focused"
                    )
                    answer_btn = gr.Button("🎯 Get Answer", variant="primary", size="lg")
                with gr.Column(scale=1):
                    gr.Markdown("### 💡 Tips\n- Process an article first\n- Ask specific questions\n- Adjust retrieval count for better results")
            answer_output = gr.Markdown(label="🤖 Generated Answer")
            with gr.Row():
                similarity_plot = gr.Plot(label="🎯 Similarity Scores")
                confidence_gauge = gr.Plot(label="📊 Confidence Meter")
        with gr.Tab("📚 Retrieved Context"):
            retrieved_chunks = gr.Markdown(
                label="📚 Retrieved Chunks",
                value="*Ask a question to see retrieved context chunks*"
            )
        # Event handlers
        process_btn.click(
            fn=process_article,
            inputs=[topic_input, chunk_size, chunk_overlap],
            outputs=[processing_status, chunk_plot, processing_summary]
        )
        answer_btn.click(
            fn=answer_question,
            inputs=[question_input, k_retrieval],
            outputs=[answer_output, similarity_plot, retrieved_chunks, confidence_gauge]
        )
        # Footer
        gr.Markdown("""
---
<div style="text-align: center; color: #666; padding: 1rem;">
🚀 RAG Pipeline Demo | Built with ❤️ using Gradio, Hugging Face, and FAISS<br>
🤗 Models: sentence-transformers/all-mpnet-base-v2 | deepset/roberta-base-squad2
</div>
        """)
    return interface


# Launch the app for Hugging Face Spaces
if __name__ == "__main__":
    interface = create_interface()
    interface.launch()
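
# On Spaces, launch() with no arguments is sufficient; when running locally you could optionally
# pass e.g. interface.launch(share=True) to get a temporary public link.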