|
import datetime
import io
import os
import tempfile
import time
import traceback

import gradio as gr
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
|
|
# Model identifier (as passed to SentenceTransformer) -> label shown in the
# UI and in the text report. Iteration order is the order models are tested.
EMBEDDING_MODELS = {
    # NOTE(review): the label says "Multilingual", but this checkpoint is
    # presumably English-focused -- confirm against the model card.
    "sentence-transformers/all-MiniLM-L6-v2": "MiniLM (Multilingual)",
    "ai-forever/FRIDA": "FRIDA (RU-EN)",
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2": "Multilingual MiniLM",
    "cointegrated/rubert-tiny2": "RuBERT Tiny",
    "ai-forever/sbert_large_nlu_ru": "Russian SBERT Large",
}

# Chunking / retrieval settings shared by every model under test.
CHUNK_SIZE = 1024  # splitter chunk_size, measured with len()
CHUNK_OVERLAP = 200  # splitter chunk_overlap
TOP_K_RESULTS = 4  # best-scoring chunks kept per model
OUTPUT_FILENAME = "rag_embedding_test_results.txt"

# Module-level state shared between the "embed" and "compare" handlers.
embeddings_cache = {}  # model id -> bookkeeping dict
document_chunks = []  # chunks of the most recently processed document
current_document = ""  # raw text of the most recently processed document
|
|
|
def chunk_document(text):
    """Split *text* into overlapping chunks and drop the near-empty ones.

    The text is cut by ``RecursiveCharacterTextSplitter`` configured with the
    module-level ``CHUNK_SIZE`` and ``CHUNK_OVERLAP`` (lengths measured with
    ``len``). A chunk is kept only if it has more than 50 characters once
    surrounding whitespace is ignored; kept chunks are returned unstripped.

    Returns:
        The list of surviving chunk strings, in splitter order (may be empty).
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=CHUNK_OVERLAP,
        length_function=len,
    )

    kept = []
    for piece in splitter.split_text(text):
        # Measure without padding, but keep the chunk text verbatim.
        if len(piece.strip()) > 50:
            kept.append(piece)
    return kept
|
|
|
def test_single_model(model_name, chunks, question):
    """Load one embedding model and rank *chunks* against *question*.

    Args:
        model_name: Identifier passed to ``SentenceTransformer``.
        chunks: List of chunk strings to embed.
        question: Query string, embedded and compared with every chunk.

    Returns:
        On success, a dict with ``status='success'``, the durations
        ``total_time`` / ``load_time`` / ``embed_time`` in seconds, and
        ``top_chunks``: at most ``TOP_K_RESULTS`` dicts with ``index``
        (int), ``score`` (float cosine similarity) and ``text``, ordered
        best match first.
        On any exception, a dict with ``status='failed'``, ``error`` and
        ``traceback``; nothing is raised to the caller.
    """
    try:
        # perf_counter is monotonic, so the measured durations cannot be
        # skewed by wall-clock adjustments.
        started = time.perf_counter()
        model = SentenceTransformer(model_name)
        load_time = time.perf_counter() - started

        embed_started = time.perf_counter()
        chunk_embeddings = model.encode(chunks, show_progress_bar=False)
        question_embedding = model.encode([question], show_progress_bar=False)
        embed_time = time.perf_counter() - embed_started

        # One query row against all chunk rows -> take the single result row.
        similarities = cosine_similarity(question_embedding, chunk_embeddings)[0]

        # argsort is ascending: keep the last K entries and reverse them so
        # the best match comes first.
        top_indices = np.argsort(similarities)[-TOP_K_RESULTS:][::-1]

        top_chunks = [
            {
                # Plain int/float keep numpy scalar types out of the result.
                'index': int(idx),
                'score': float(similarities[idx]),
                'text': chunks[int(idx)],
            }
            for idx in top_indices
        ]

        return {
            'status': 'success',
            'total_time': time.perf_counter() - started,
            'load_time': load_time,
            'embed_time': embed_time,
            'top_chunks': top_chunks,
        }

    except Exception as e:
        # Boundary handler: a model that fails to load or encode must not
        # abort the comparison of the remaining models.
        return {
            'status': 'failed',
            'error': str(e),
            'traceback': traceback.format_exc(),
        }
|
|
|
def process_embeddings(document_text, progress=gr.Progress()):
    """Chunk the document and reset the per-model state for a new comparison.

    No embeddings are computed here: the document is split with
    ``chunk_document`` and ``embeddings_cache`` is refilled with one
    placeholder entry (``processed=False``) per model in ``EMBEDDING_MODELS``.

    Args:
        document_text: Raw document text; ``None`` or blank is rejected.
        progress: Progress callback, called with a fraction and a ``desc``.

    Returns:
        A status message: an error line when the text is missing or yields
        no chunks, otherwise a summary with chunk and model counts.

    Side effects:
        Rebinds the module globals ``current_document``, ``document_chunks``
        and ``embeddings_cache``.
    """
    global embeddings_cache, document_chunks, current_document

    if not document_text or not document_text.strip():
        return "❌ Please provide document text first!"

    current_document = document_text

    progress(0.1, desc="Chunking document...")
    document_chunks = chunk_document(document_text)
    # Reset before the emptiness check so entries from a previous document
    # never outlive the chunks they were created for.
    embeddings_cache = {}

    if not document_chunks:
        return "❌ No valid chunks created. Please provide longer text."

    total_models = len(EMBEDDING_MODELS)
    progress(0.2, desc=f"Processing {len(document_chunks)} chunks with {total_models} models...")

    for i, (model_name, display_name) in enumerate(EMBEDDING_MODELS.items()):
        progress(0.2 + (0.7 * i / total_models), desc=f"Testing {display_name}...")
        embeddings_cache[model_name] = {
            'processed': False,
            'display_name': display_name,
        }

    progress(1.0, desc="Ready for testing!")

    return (
        "✅ Document processed successfully!\n\n"
        "📊 **Stats:**\n"
        f"- Total chunks: {len(document_chunks)}\n"
        f"- Chunk size: {CHUNK_SIZE}\n"
        f"- Chunk overlap: {CHUNK_OVERLAP}\n"
        f"- Models ready: {len(EMBEDDING_MODELS)}\n\n"
        "🔍 **Now ask a question to compare embedding models!**"
    )
|
|
|
def compare_embeddings(question, progress=gr.Progress()):
    """Run every model in ``EMBEDDING_MODELS`` against the current chunks.

    Args:
        question: Query string; ``None`` or blank is rejected.
        progress: Progress callback, called with a fraction and a ``desc``.

    Returns:
        A ``(display_markdown, report_text)`` tuple. When the question is
        missing or no document has been chunked yet, the first item is an
        error message and the second is the empty string.
    """
    if not question or not question.strip():
        return "❌ Please enter a question!", ""

    if not document_chunks:
        return "❌ Please process a document first using 'Start Embedding' button!", ""

    results = {}
    total_models = len(EMBEDDING_MODELS)

    for i, (model_name, display_name) in enumerate(EMBEDDING_MODELS.items()):
        progress(i / total_models, desc=f"Testing {display_name}...")

        result = test_single_model(model_name, document_chunks, question)
        # The formatters read the label from the result itself.
        result['display_name'] = display_name
        results[model_name] = result

    progress(1.0, desc="Comparison complete!")

    display_results = format_comparison_results(results, question)
    report_content = generate_report(results, question)
    return display_results, report_content
|
|
|
def format_comparison_results(results, question):
    """Render the per-model results as Markdown for the Gradio display.

    Args:
        results: Mapping of model id to a result dict carrying
            ``display_name`` and ``status``, plus ``total_time`` and
            ``top_chunks`` on success or ``error`` on failure.
        question: The question that was asked, echoed in the header.

    Returns:
        One Markdown string: a header (question and the current number of
        ``document_chunks``) followed by one section per model. Chunk
        previews are cut to 200 characters with a trailing ``...``.
    """
    parts = [
        "# 🔍 Embedding Model Comparison\n\n",
        f"**Question:** {question}\n\n",
        f"**Document chunks:** {len(document_chunks)}\n\n",
        "---\n\n",
    ]

    for result in results.values():
        parts.append(f"## 🤖 {result['display_name']}\n\n")

        if result['status'] == 'success':
            parts.append(f"✅ **Success** ({result['total_time']:.2f}s)\n\n")
            parts.append("**Top Results:**\n\n")

            for rank, chunk in enumerate(result['top_chunks'], 1):
                text = chunk['text']
                preview = text[:200] + "..." if len(text) > 200 else text
                parts.append(f"**{rank}. [{chunk['score']:.3f}]** Chunk #{chunk['index']}\n")
                parts.append(f"```\n{preview}\n```\n\n")
        else:
            parts.append(f"❌ **Failed:** {result['error']}\n\n")

        parts.append("---\n\n")

    # Join once instead of growing a string with repeated +=.
    return "".join(parts)
|
|
|
def generate_report(results, question):
    """Build the plain-text report offered for download.

    Args:
        results: Mapping of model id to a result dict carrying
            ``display_name`` and ``status``, plus ``total_time`` and
            ``top_chunks`` on success or ``error`` on failure.
        question: The question that was asked.

    Returns:
        The report as a single string: a header with the current local
        timestamp, the question, the number of ``document_chunks`` and the
        chunking settings, then one section per model. Chunk text is
        flattened to one line and cut to 100 characters with a trailing
        ``...``.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    banner = "==========================================\n"

    lines = [
        banner,
        "RAG EMBEDDING MODEL TEST RESULTS\n",
        banner,
        f"Date: {timestamp}\n",
        f"Question: {question}\n",
        f"Document chunks: {len(document_chunks)}\n\n",
        "Settings:\n",
        f"- Chunk Size: {CHUNK_SIZE}\n",
        f"- Chunk Overlap: {CHUNK_OVERLAP}\n",
        "- Splitter: RecursiveCharacterTextSplitter\n",
        f"- Top-K Results: {TOP_K_RESULTS}\n\n",
        banner,
    ]

    for result in results.values():
        lines.append(f"MODEL: {result['display_name']}\n")

        if result['status'] == 'success':
            lines.append(f"Status: ✅ Success ({result['total_time']:.2f}s)\n")
            lines.append("Top Results:\n")

            for chunk in result['top_chunks']:
                # Keep every hit on a single report line.
                text = chunk['text'].replace('\n', ' ')
                preview = text[:100] + "..." if len(text) > 100 else text
                lines.append(f"[{chunk['score']:.3f}] Chunk #{chunk['index']}: \"{preview}\"\n")
        else:
            lines.append(f"Status: ❌ Failed - {result['error']}\n")

        lines.append("\n" + "=" * 40 + "\n")

    # Join once instead of growing a string with repeated +=.
    return "".join(lines)
|
|
|
def load_file(file):
    """Return the text content of an uploaded file.

    Args:
        file: ``None``, a filesystem path (``str`` / ``os.PathLike``), an
            object whose ``.name`` is the path of an existing file (e.g. a
            temp-file wrapper), or any object with a ``read()`` method.

    Returns:
        ``""`` when *file* is ``None``; otherwise the file's text. Bytes are
        decoded as UTF-8 and a leading byte-order mark is dropped. On any
        failure the string ``"Error loading file: <reason>"`` is returned
        instead of raising, so the message lands in the target textbox.
    """
    if file is None:
        return ""

    try:
        if isinstance(file, (str, os.PathLike)):
            path = os.fspath(file)
        else:
            # Prefer the on-disk path when the object advertises one: the
            # object's own read() may be exhausted or point elsewhere.
            candidate = getattr(file, "name", None)
            is_real_file = isinstance(candidate, str) and os.path.isfile(candidate)
            path = candidate if is_real_file else None

        if path is not None:
            with open(path, "rb") as handle:
                content = handle.read()
        else:
            content = file.read()

        if isinstance(content, bytes):
            # utf-8-sig decodes plain UTF-8 too and strips a BOM if present.
            content = content.decode("utf-8-sig")
        return content
    except Exception as e:
        # UI boundary: report the problem as text rather than crash the handler.
        return f"Error loading file: {e}"
|
|
|
|
|
# UI layout and event wiring. Components and listeners are declared inside
# the Blocks context so they are attached to `demo`.
with gr.Blocks(title="RAG Embedding Model Tester", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧪 RAG Embedding Model Tester")
    gr.Markdown("Test and compare different embedding models for RAG pipelines. Focus on relevance quality assessment.")

    with gr.Row():
        # Left column: document input and chunking.
        with gr.Column(scale=1):
            gr.Markdown("## 📄 Document Input")

            document_input = gr.Textbox(
                lines=15,
                placeholder="Paste your document text here (Russian or English)...",
                label="Document Text",
                max_lines=20,
            )

            file_input = gr.File(
                file_types=[".txt", ".md"],
                label="Or Upload Text File",
            )

            # An uploaded file replaces the textbox content.
            file_input.change(
                fn=load_file,
                inputs=file_input,
                outputs=document_input,
            )

            embed_btn = gr.Button("🚀 Start Embedding Process", variant="primary", size="lg")
            embed_status = gr.Textbox(label="Processing Status", lines=8)

        # Right column: question, comparison output and report download.
        with gr.Column(scale=2):
            gr.Markdown("## ❓ Question & Comparison")

            question_input = gr.Textbox(
                placeholder="What question do you want to ask about the document?",
                label="Your Question",
                lines=2,
            )

            compare_btn = gr.Button("🔍 Compare All Models", variant="secondary", size="lg")

            results_display = gr.Markdown(label="Comparison Results")

            gr.Markdown("## 📥 Download Results")
            report_download = gr.File(label="Download Test Report")

    with gr.Row():
        # Built from unindented lines so no line can render as a code block.
        model_labels = ', '.join(f"**{name}**" for name in EMBEDDING_MODELS.values())
        gr.Markdown("\n".join([
            "## 🤖 Models to Test:",
            model_labels,
            "",
            "## ⚙️ Settings:",
            f"- **Chunk Size:** {CHUNK_SIZE} characters",
            f"- **Chunk Overlap:** {CHUNK_OVERLAP} characters",
            f"- **Top Results:** {TOP_K_RESULTS} chunks per model",
            "- **Splitter:** RecursiveCharacterTextSplitter",
        ]))

    embed_btn.click(
        fn=process_embeddings,
        inputs=document_input,
        outputs=embed_status,
    )

    def compare_and_download(question, progress=gr.Progress()):
        """Run the comparison and expose the report as a downloadable file.

        Returns:
            ``(results_markdown, report_path)``. ``report_path`` is the path
            of a freshly written UTF-8 text file, or ``None`` when the
            comparison produced no report (validation failed upstream).
        """
        # Forward the tracker so the per-model progress reaches the UI.
        results_text, report_content = compare_embeddings(question, progress)

        if not report_content:
            return results_text, None

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"rag_test_{timestamp}.txt"

        # The File output takes a path, not the file's text: write the report
        # into its own temp directory so the download keeps this filename.
        report_path = os.path.join(tempfile.mkdtemp(prefix="rag_test_"), filename)
        with open(report_path, "w", encoding="utf-8") as handle:
            handle.write(report_content)

        return results_text, report_path

    compare_btn.click(
        fn=compare_and_download,
        inputs=question_input,
        outputs=[results_display, report_download],
    )


if __name__ == "__main__":
    demo.launch()