VitaliyPolovyyEN's picture
Update app.py
4e72327 verified
raw
history blame
11.2 kB
import datetime
import io
import os
import tempfile
import time
import traceback

import gradio as gr
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
# Configuration
# Hugging Face model id -> human-readable label used in the UI and in reports.
EMBEDDING_MODELS = {
    # all-MiniLM-L6-v2 is an English-only model; it was previously labelled
    # "Multilingual", which misrepresented it and was nearly identical to the
    # label of the genuinely multilingual paraphrase model below.
    "sentence-transformers/all-MiniLM-L6-v2": "MiniLM (English)",
    "ai-forever/FRIDA": "FRIDA (RU-EN)",
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": "Multilingual MiniLM",
    "cointegrated/rubert-tiny2": "RuBERT Tiny",
    "ai-forever/sbert_large_nlu_ru": "Russian SBERT Large",
}
CHUNK_SIZE = 1024  # target chunk length in characters (splitter measures with len)
CHUNK_OVERLAP = 200  # characters shared between consecutive chunks
TOP_K_RESULTS = 4  # number of best-matching chunks reported per model
OUTPUT_FILENAME = "rag_embedding_test_results.txt"  # NOTE(review): not referenced in the visible code
# Global storage: module-level state, shared by every user/session of this process.
embeddings_cache = {}  # model id -> {'processed': bool, 'display_name': str}
document_chunks = []  # chunks of the most recently processed document
current_document = ""  # raw text of the most recently processed document
def chunk_document(text):
    """Break *text* into overlapping, character-sized chunks for retrieval.

    Uses LangChain's RecursiveCharacterTextSplitter configured from the module
    constants CHUNK_SIZE and CHUNK_OVERLAP, with length measured by ``len``
    (i.e. in characters). Any chunk whose whitespace-stripped length is 50
    characters or fewer is dropped.

    Returns:
        list[str]: the surviving chunks, in document order.
    """
    splitter = RecursiveCharacterTextSplitter(
        length_function=len,
        chunk_overlap=CHUNK_OVERLAP,
        chunk_size=CHUNK_SIZE,
    )
    kept = []
    for piece in splitter.split_text(text):
        # Skip tiny fragments (presumably stray headings / whitespace runs).
        if len(piece.strip()) <= 50:
            continue
        kept.append(piece)
    return kept
def test_single_model(model_name, chunks, question, top_k=None):
    """Embed chunks and a question with one model and rank chunks by similarity.

    Args:
        model_name: Model id passed to ``SentenceTransformer``. The model is
            loaded fresh on every call, so ``load_time`` includes any
            download/initialisation cost.
        chunks: Chunk strings to score against the question.
        question: Query text.
        top_k: Number of best chunks to return; ``None`` means
            ``TOP_K_RESULTS``. Values <= 0 return no chunks.

    Returns:
        dict: On success ``{'status': 'success', 'total_time', 'load_time',
        'embed_time', 'top_chunks'}``. ``top_chunks`` is a list of
        ``{'index': int, 'score': float, 'text': str}`` sorted by descending
        cosine similarity.

        Any exception is caught and reported as ``{'status': 'failed',
        'error': str, 'traceback': str}``, so one broken model does not abort
        a comparison loop over several models.
    """
    limit = TOP_K_RESULTS if top_k is None else top_k
    try:
        # perf_counter is monotonic and high-resolution; time.time() is
        # wall-clock time and can jump if the system clock is adjusted.
        start_time = time.perf_counter()
        model = SentenceTransformer(model_name)
        load_time = time.perf_counter() - start_time

        embed_start = time.perf_counter()
        chunk_embeddings = model.encode(chunks, show_progress_bar=False)
        question_embedding = model.encode([question], show_progress_bar=False)
        embed_time = time.perf_counter() - embed_start

        # One query row vs. all chunk rows -> row 0 holds the per-chunk scores.
        similarities = cosine_similarity(question_embedding, chunk_embeddings)[0]

        # Best-first ordering. Slicing *after* reversing keeps limit == 0 empty;
        # the old ``[-k:]`` form returned every chunk when k == 0.
        top_indices = np.argsort(similarities)[::-1][:max(limit, 0)]

        total_time = time.perf_counter() - start_time
        return {
            'status': 'success',
            'total_time': total_time,
            'load_time': load_time,
            'embed_time': embed_time,
            'top_chunks': [
                {
                    # Plain Python numbers instead of numpy scalars (JSON-safe).
                    'index': int(idx),
                    'score': float(similarities[idx]),
                    'text': chunks[idx],
                }
                for idx in top_indices
            ],
        }
    except Exception as e:
        return {
            'status': 'failed',
            'error': str(e),
            'traceback': traceback.format_exc(),
        }
def process_embeddings(document_text, progress=gr.Progress()):
    """Chunk the uploaded document and register every configured model as pending.

    Side effects: rebinds the module globals ``current_document``,
    ``document_chunks`` and ``embeddings_cache``. No embeddings are computed
    here; the cache only records each model id as ``'processed': False`` with
    its display name.

    Returns:
        str: a status message for the UI (an error message when the text is
        blank or yields no usable chunks).
    """
    global embeddings_cache, document_chunks, current_document

    if not document_text.strip():
        return "❌ Please provide document text first!"

    current_document = document_text

    progress(0.1, desc="Chunking document...")
    document_chunks = chunk_document(document_text)
    if not document_chunks:
        return "❌ No valid chunks created. Please provide longer text."

    model_count = len(EMBEDDING_MODELS)
    embeddings_cache = {}
    progress(0.2, desc=f"Processing {len(document_chunks)} chunks with {model_count} models...")

    for position, (model_id, label) in enumerate(EMBEDDING_MODELS.items()):
        progress(0.2 + (0.7 * position / model_count), desc=f"Testing {label}...")
        # Preparation only: models actually run when a question is asked.
        embeddings_cache[model_id] = {'processed': False, 'display_name': label}

    progress(1.0, desc="Ready for testing!")
    return (
        "βœ… Document processed successfully!\n\n"
        "πŸ“Š **Stats:**\n"
        f"- Total chunks: {len(document_chunks)}\n"
        f"- Chunk size: {CHUNK_SIZE}\n"
        f"- Chunk overlap: {CHUNK_OVERLAP}\n"
        f"- Models ready: {len(EMBEDDING_MODELS)}\n\n"
        "πŸ” **Now ask a question to compare embedding models!**"
    )
def compare_embeddings(question, progress=gr.Progress()):
    """Run every configured embedding model against *question*.

    Uses the module-level ``document_chunks`` produced by ``process_embeddings``.

    Returns:
        tuple[str, str]: ``(markdown_results, plain_text_report)``, or
        ``(error_message, "")`` when the question is blank or no document has
        been processed yet.
    """
    if not question.strip():
        return "❌ Please enter a question!", ""
    if not document_chunks:
        return "❌ Please process a document first using 'Start Embedding' button!", ""

    model_count = len(EMBEDDING_MODELS)
    results = {}
    for step, (model_id, label) in enumerate(EMBEDDING_MODELS.items()):
        progress(step / model_count, desc=f"Testing {label}...")
        outcome = test_single_model(model_id, document_chunks, question)
        outcome['display_name'] = label
        results[model_id] = outcome
    progress(1.0, desc="Comparison complete!")

    # Markdown for the page first, then the downloadable plain-text report.
    return (
        format_comparison_results(results, question),
        generate_report(results, question),
    )
def format_comparison_results(results, question):
    """Render per-model comparison results as a Markdown document.

    Args:
        results: Mapping of model id -> result dict from ``test_single_model``
            with an added ``'display_name'`` key.
        question: The query text, echoed in the header.

    Reads the module-level ``document_chunks`` to report the chunk count.
    Chunk previews are cut to 200 characters (plus "...").
    """
    parts = [
        "# πŸ” Embedding Model Comparison\n\n",
        f"**Question:** {question}\n\n",
        f"**Document chunks:** {len(document_chunks)}\n\n",
        "---\n\n",
    ]
    for result in results.values():
        parts.append(f"## πŸ€– {result['display_name']}\n\n")
        if result['status'] != 'success':
            parts.append(f"❌ **Failed:** {result['error']}\n\n")
        else:
            parts.append(f"βœ… **Success** ({result['total_time']:.2f}s)\n\n")
            parts.append("**Top Results:**\n\n")
            for rank, chunk in enumerate(result['top_chunks'], start=1):
                body = chunk['text']
                preview = body if len(body) <= 200 else body[:200] + "..."
                parts.append(f"**{rank}. [{chunk['score']:.3f}]** Chunk #{chunk['index']}\n")
                parts.append(f"```\n{preview}\n```\n\n")
        parts.append("---\n\n")
    return "".join(parts)
def generate_report(results, question):
    """Build the plain-text comparison report offered for download.

    Args:
        results: Mapping of model id -> result dict from ``test_single_model``
            with an added ``'display_name'`` key.
        question: The query text.

    The header lists a local timestamp, the question, the chunk count (from
    the module-level ``document_chunks``) and the chunking/top-K settings.
    Each chunk snippet has newlines flattened to spaces and is cut to 100
    characters (plus "...").
    """
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    banner = "==========================================\n"
    out = [
        banner,
        "RAG EMBEDDING MODEL TEST RESULTS\n",
        banner,
        f"Date: {stamp}\n",
        f"Question: {question}\n",
        f"Document chunks: {len(document_chunks)}\n\n",
        "Settings:\n",
        f"- Chunk Size: {CHUNK_SIZE}\n",
        f"- Chunk Overlap: {CHUNK_OVERLAP}\n",
        "- Splitter: RecursiveCharacterTextSplitter\n",
        f"- Top-K Results: {TOP_K_RESULTS}\n\n",
        banner,
    ]
    for result in results.values():
        out.append(f"MODEL: {result['display_name']}\n")
        if result['status'] == 'success':
            out.append(f"Status: βœ… Success ({result['total_time']:.2f}s)\n")
            out.append("Top Results:\n")
            for chunk in result['top_chunks']:
                flat = chunk['text'].replace('\n', ' ')
                snippet = flat[:100] + "..." if len(flat) > 100 else flat
                out.append(f"[{chunk['score']:.3f}] Chunk #{chunk['index']}: \"{snippet}\"\n")
        else:
            out.append(f"Status: ❌ Failed - {result['error']}\n")
        out.append("\n" + "=" * 40 + "\n")
    return "".join(out)
def load_file(file):
    """Return the text content of an uploaded file for the document textbox.

    Accepts the shapes a gr.File callback may receive depending on the Gradio
    version and ``type`` setting:
      - a filesystem path (str / os.PathLike);
      - an object exposing its path as ``.name`` (e.g. a tempfile wrapper);
      - raw ``bytes``;
      - any readable file-like object.

    Bytes are decoded as UTF-8, dropping a leading byte-order mark if present.

    Returns:
        str: the file text; "" for None; or "Error loading file: <reason>"
        when reading/decoding fails. The message is shown in the textbox
        instead of raising.
    """
    if file is None:
        return ""
    try:
        if isinstance(file, (bytes, bytearray)):
            content = bytes(file)
        elif isinstance(file, (str, os.PathLike)):
            # Path string: the old code called .read() on it and always failed.
            with open(file, "rb") as handle:
                content = handle.read()
        elif isinstance(getattr(file, "name", None), str) and os.path.isfile(file.name):
            # Read via the path: an already-written temp file's cursor may sit
            # at EOF, in which case .read() would silently return nothing.
            with open(file.name, "rb") as handle:
                content = handle.read()
        else:
            content = file.read()
        if isinstance(content, bytes):
            # utf-8-sig decodes plain UTF-8 unchanged and also strips the BOM
            # some editors (e.g. Windows Notepad) prepend.
            content = content.decode("utf-8-sig")
        return content
    except Exception as e:
        return f"Error loading file: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="RAG Embedding Model Tester", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# πŸ§ͺ RAG Embedding Model Tester")
    gr.Markdown("Test and compare different embedding models for RAG pipelines. Focus on relevance quality assessment.")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## πŸ“„ Document Input")
            document_input = gr.Textbox(
                lines=15,
                placeholder="Paste your document text here (Russian or English)...",
                label="Document Text",
                max_lines=20,
            )
            file_input = gr.File(
                file_types=[".txt", ".md"],
                label="Or Upload Text File",
            )
            # Load file content to text box
            file_input.change(
                fn=load_file,
                inputs=file_input,
                outputs=document_input,
            )
            embed_btn = gr.Button("πŸš€ Start Embedding Process", variant="primary", size="lg")
            embed_status = gr.Textbox(label="Processing Status", lines=8)
        with gr.Column(scale=2):
            gr.Markdown("## ❓ Question & Comparison")
            question_input = gr.Textbox(
                placeholder="What question do you want to ask about the document?",
                label="Your Question",
                lines=2,
            )
            compare_btn = gr.Button("πŸ” Compare All Models", variant="secondary", size="lg")
            results_display = gr.Markdown(label="Comparison Results")
            gr.Markdown("## πŸ“₯ Download Results")
            report_download = gr.File(label="Download Test Report")
    # Model info
    with gr.Row():
        gr.Markdown(f"""
## πŸ€– Models to Test:
{', '.join([f"**{name}**" for name in EMBEDDING_MODELS.values()])}
## βš™οΈ Settings:
- **Chunk Size:** {CHUNK_SIZE} characters
- **Chunk Overlap:** {CHUNK_OVERLAP} characters
- **Top Results:** {TOP_K_RESULTS} chunks per model
- **Splitter:** RecursiveCharacterTextSplitter
""")
    # Event handlers
    embed_btn.click(
        fn=process_embeddings,
        inputs=document_input,
        outputs=embed_status,
    )

    def compare_and_download(question, progress=gr.Progress()):
        """Run the model comparison and expose the text report as a download.

        Returns:
            tuple: (Markdown results, path to a UTF-8 ``.txt`` report). The path
            is None when no report was produced (blank question / no document),
            which clears the File output.
        """
        # Gradio only tracks a gr.Progress() declared on the event handler
        # itself, so forward it; compare_embeddings' own default would not be
        # connected to the UI when called from here.
        results_text, report_content = compare_embeddings(question, progress=progress)
        if not report_content:
            return results_text, None

        # A gr.File output needs a file *path*; the old code passed the report
        # text itself (via gr.File.update, which Gradio 4 also removed).
        # mkdtemp gives each run its own directory, so concurrent runs in the
        # same second cannot overwrite each other's rag_test_<timestamp>.txt.
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        report_path = os.path.join(
            tempfile.mkdtemp(prefix="rag_test_"), f"rag_test_{timestamp}.txt"
        )
        with open(report_path, "w", encoding="utf-8") as report_file:
            report_file.write(report_content)
        return results_text, report_path

    compare_btn.click(
        fn=compare_and_download,
        inputs=question_input,
        outputs=[results_display, report_download],
    )
def _launch_app():
    """Start the Gradio server for the ``demo`` Blocks app."""
    demo.launch()


if __name__ == "__main__":
    # Launch only when executed as a script, not when imported.
    _launch_app()