Spaces:
Build error
Build error
File size: 4,719 Bytes
7f0844d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import gradio as gr
import utils
from datasets import load_dataset, concatenate_datasets
from langchain.docstore.document import Document as LangchainDocument
from tqdm import tqdm
import pickle
from ragatouille import RAGPretrainedModel
import chunker
import retriver
import rag
import nltk
import config
import os
import warnings
import sys
import logging
# Quiet third-party noise for the demo: install an "ignore" filter for every
# Python warning category, and raise the "langchain" logger's threshold so
# only ERROR and above get through.
warnings.simplefilter("ignore")
logging.getLogger("langchain").setLevel(logging.ERROR)
class AnswerSystem:
    """Adapter between the Gradio UI and a RAG answer generator.

    Forwards a question plus retrieval settings to ``rag_system.answer`` and
    turns the returned documents into one display string.
    """

    def __init__(self, rag_system, num_retrieved_docs=10, num_docs_final=5) -> None:
        """Store the RAG backend and the retrieval sizes used for every query.

        Args:
            rag_system: Object exposing ``answer(**kwargs)`` that returns an
                ``(answer, relevant_docs)`` pair.
            num_retrieved_docs: Forwarded as ``num_retrieved_docs`` to
                ``rag_system.answer`` (presumably the candidate count before
                reranking - confirm in ``rag``). Defaults to 10.
            num_docs_final: Forwarded as ``num_docs_final`` to
                ``rag_system.answer`` (presumably how many documents are kept
                for the final answer - confirm in ``rag``). Defaults to 5.
        """
        self.rag_system = rag_system
        self.num_retrieved_docs = num_retrieved_docs
        self.num_docs_final = num_docs_final

    def answer_generate(self, question, bm_25_flag, semantic_flag, temperature):
        """Answer *question* and format the supporting documents for display.

        The positional order matches the Gradio ``inputs`` list in ``run_app``.

        Args:
            question: User question text.
            bm_25_flag: Forwarded to the backend to toggle BM25 retrieval.
            semantic_flag: Forwarded to the backend to toggle semantic search.
            temperature: Sampling temperature forwarded to the backend.

        Returns:
            ``(answer, formatted_docs)``, where ``formatted_docs`` is each
            document rendered as ``"Document <n>: <doc>"`` (1-based), separated
            by blank lines. It is an empty string when no documents come back.
        """
        answer, relevant_docs = self.rag_system.answer(
            question=question,
            temperature=temperature,
            bm_25_flag=bm_25_flag,
            semantic_flag=semantic_flag,
            num_retrieved_docs=self.num_retrieved_docs,
            num_docs_final=self.num_docs_final,
        )
        formatted_docs = "\n\n".join(
            f"Document {number}: {doc}"
            for number, doc in enumerate(relevant_docs, start=1)
        )
        return answer, formatted_docs
def run_app(rag_model, *, debug=True, share=True):
    """Build the Gradio question-answering UI around *rag_model* and launch it.

    Args:
        rag_model: RAG backend wrapped in an ``AnswerSystem``. It must provide
            the ``answer(...)`` method that ``AnswerSystem.answer_generate``
            calls.
        debug: Forwarded to ``demo.launch``. Defaults to ``True``, the
            previously hard-coded value.
        share: Forwarded to ``demo.launch``. Defaults to ``True``, the
            previously hard-coded value, which asks Gradio for a public share
            link. Pass ``False`` to keep the app local.
    """
    # One example question, reused in the intro text and as the input placeholder.
    example_question = (
        "What position does Josko Gvardiol play, "
        "and how much did Manchester City pay for him?"
    )
    intro_markdown = (
        "\n"
        "# RealTimeData Monthly Collection - BBC News Documentation Assistant\n"
        "Welcome! This system is designed to help you explore and find insights "
        "from the RealTimeData Monthly Collection - BBC News dataset.\n"
        "For example:\n"
        f'- *"{example_question}"*\n'
    )

    with gr.Blocks() as demo:
        gr.Markdown(intro_markdown)

        # Input fields
        question_input = gr.Textbox(
            label="Enter your question:",
            placeholder=f"E.g., {example_question}",
        )
        bm25_checkbox = gr.Checkbox(label="Enable BM25-based retrieval", value=True)
        semantic_checkbox = gr.Checkbox(label="Enable Semantic Search", value=True)
        temperature_slider = gr.Slider(
            label="Response Temperature",
            minimum=0.1,
            maximum=1.0,
            value=0.5,
            step=0.1,
        )

        # Search button
        search_button = gr.Button("Search")

        # Output fields (read-only)
        answer_output = gr.Textbox(label="Answer", interactive=False, lines=5)
        docs_output = gr.Textbox(label="Relevant Documents", interactive=False, lines=10)

        # Search logic: the order of `inputs` must match the positional
        # parameters of AnswerSystem.answer_generate.
        system = AnswerSystem(rag_model)
        search_button.click(
            system.answer_generate,
            inputs=[question_input, bm25_checkbox, semantic_checkbox, temperature_slider],
            outputs=[answer_output, docs_output],
        )

    # Launch the app once the Blocks layout has been fully defined.
    demo.launch(debug=debug, share=share)
def get_rag_data():
    """Return the chunked BBC News documents, using an on-disk pickle cache.

    Always downloads the NLTK ``punkt`` and ``punkt_tab`` resources.

    If ``config.DOCUMENTS_PATH`` exists, the chunks are unpickled from it.
    Otherwise the function:

    1. Loads every config in ``config.AVAILABLE_DATASET_CONFIGS`` of the
       ``RealTimeData/bbc_news_alltime`` dataset.
    2. Wraps each row in a LangChain ``Document``.
    3. Splits the documents with ``chunker.split_documents(512, ...)``.
    4. Writes the result to the cache path before returning it.

    Returns:
        Whatever ``chunker.split_documents`` returns (presumably a list of
        chunked documents - confirm in ``chunker``).
    """
    nltk.download('punkt')
    nltk.download('punkt_tab')
    cache_path = config.DOCUMENTS_PATH

    if os.path.exists(cache_path):
        print(f"Loading preprocessed documents from {cache_path}")
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # DOCUMENTS_PATH at a cache file produced by this application.
        with open(cache_path, "rb") as file:
            return pickle.load(file)

    print("Processing documents...")
    # 512 is the first argument to chunker.split_documents (presumably the
    # chunk size - confirm in chunker).
    docs_processed = chunker.split_documents(512, _build_knowledge_base())
    print(f"Saving preprocessed documents to {cache_path}")
    _write_cache(cache_path, docs_processed)
    return docs_processed


def _build_knowledge_base():
    """Load all configured BBC News dataset configs as LangChain documents."""
    # Named `dataset_config` rather than `config` so it does not shadow the
    # project `config` module read in the same expression.
    datasets_list = [
        utils.align_features(
            load_dataset("RealTimeData/bbc_news_alltime", dataset_config)["train"]
        )
        for dataset_config in tqdm(config.AVAILABLE_DATASET_CONFIGS)
    ]
    ds = concatenate_datasets(datasets_list)
    metadata_fields = ("title", "published_date", "authors", "section", "description", "link")
    return [
        LangchainDocument(
            page_content=row["content"],
            metadata={field: row[field] for field in metadata_fields},
        )
        for row in tqdm(ds)
    ]


def _write_cache(path, docs):
    """Pickle *docs* to *path* atomically, creating the parent directory if needed.

    The data is written to ``<path>.tmp`` first and then moved into place with
    ``os.replace``. An interrupted write therefore never leaves a truncated
    cache that would make every later ``pickle.load`` fail.
    """
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "wb") as file:
            pickle.dump(docs, file)
        os.replace(tmp_path, path)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
if __name__ == '__main__':
    # Build every retrieval component from the same processed chunks, then
    # hand the assembled generator to the Gradio UI.
    chunks = get_rag_data()
    bm25_retriever = retriver.create_bm25(chunks)
    vector_index = retriver.create_vector_db(chunks)
    reranker_model = RAGPretrainedModel.from_pretrained(config.CROSS_ENCODER_MODEL)
    generator = rag.RAGAnswerGenerator(
        docs=chunks,
        bm25=bm25_retriever,
        knowledge_index=vector_index,
        reranker=reranker_model,
    )
    run_app(generator)