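"""Gradio front end for a retrieval-augmented generation (RAG) assistant
over the RealTimeData Monthly Collection - BBC News dataset.

The app combines BM25 and semantic retrieval with a reranking step, then
generates answers through the project's RAGAnswerGenerator.
"""
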
import gradio as gr
import utils
from datasets import load_dataset, concatenate_datasets
from langchain.docstore.document import Document as LangchainDocument
from tqdm import tqdm
import pickle
from ragatouille import RAGPretrainedModel
import chunker
import retriver
import rag
import nltk
import config
import os
import warnings
import logging

logging.getLogger("langchain").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")


class AnswerSystem:
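    """Thin adapter that exposes the RAG system to Gradio callbacks."""
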
    def __init__(self, rag_system) -> None:
        self.rag_system = rag_system
    
    def answer_generate(self, question, bm_25_flag, semantic_flag, temperature):
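        """Answer a question with the underlying RAG system.

        Retrieves up to 10 candidate documents, narrows them to a final 5,
        and returns the generated answer together with the supporting
        documents formatted for display.
        """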
        answer, relevant_docs = self.rag_system.answer(
            question=question,
            temperature=temperature,
            bm_25_flag=bm_25_flag,
            semantic_flag=semantic_flag,
            num_retrieved_docs=10,
            num_docs_final=5
        )
        formatted_docs = "\n\n".join([f"Document {i + 1}: {doc}" for i, doc in enumerate(relevant_docs)])
        return answer, formatted_docs


def run_app(rag_model):
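    """Build and launch the Gradio interface for the given RAG model."""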
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # RealTimeData Monthly Collection - BBC News Documentation Assistant

            Welcome! This assistant helps you explore and query the RealTimeData Monthly Collection - BBC News dataset.  
            For example:  
            
            - *"What position does Josko Gvardiol play, and how much did Manchester City pay for him?"*  

            """
        )

        # Input fields
        question_input = gr.Textbox(label="Enter your question:",
                                    placeholder="E.g., What position does Josko Gvardiol play, and how much did Manchester City pay for him?")
        bm25_checkbox = gr.Checkbox(label="Enable BM25-based retrieval", value=True)  # BM25 flag
        semantic_checkbox = gr.Checkbox(label="Enable Semantic Search", value=True)  # Semantic flag
        temperature_slider = gr.Slider(label="Response Temperature", minimum=0.1, maximum=1.0, value=0.5,
                                       step=0.1)  # Temperature

        # Search button
        search_button = gr.Button("Search")

        # Output fields
        answer_output = gr.Textbox(label="Answer", interactive=False, lines=5)
        docs_output = gr.Textbox(label="Relevant Documents", interactive=False, lines=10)

        # Search logic
        system = AnswerSystem(rag_model)

        search_button.click(
            system.answer_generate,
            inputs=[question_input, bm25_checkbox, semantic_checkbox, temperature_slider],  # all input parameters
            outputs=[answer_output, docs_output]
        )

    # Launch the app
    demo.launch(debug=True, share=True)


def get_rag_data():
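    """Load (or build and cache) the chunked document collection.

    Uses the pickle cache at config.DOCUMENTS_PATH when present; otherwise
    downloads every configured split of RealTimeData/bbc_news_alltime,
    wraps each article in a LangchainDocument, splits the result into
    chunks of size 512, and caches it for future runs.
    """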
    # Make sure the required NLTK tokenizer data is available
    nltk.download('punkt')
    nltk.download('punkt_tab')

    if os.path.exists(config.DOCUMENTS_PATH):
        print(f"Loading preprocessed documents from {config.DOCUMENTS_PATH}")
        with open(config.DOCUMENTS_PATH, "rb") as file:
            docs_processed = pickle.load(file)
    else:
        print("Processing documents...")
        datasets_list = [
            utils.align_features(load_dataset("RealTimeData/bbc_news_alltime", ds_config)["train"])
            for ds_config in tqdm(config.AVAILABLE_DATASET_CONFIGS)
        ]

        ds = concatenate_datasets(datasets_list)

        RAW_KNOWLEDGE_BASE = [
            LangchainDocument(
                page_content=doc["content"],
                metadata={
                    "title": doc["title"],
                    "published_date": doc["published_date"],
                    "authors": doc["authors"],
                    "section": doc["section"],
                    "description": doc["description"],
                    "link": doc["link"]
                }
            )
            for doc in tqdm(ds)
        ]

        docs_processed = chunker.split_documents(512, RAW_KNOWLEDGE_BASE)

        print(f"Saving preprocessed documents to {config.DOCUMENTS_PATH}")
        with open(config.DOCUMENTS_PATH, "wb") as file:
            pickle.dump(docs_processed, file)

    return docs_processed


if __name__ == '__main__':
    docs_processed = get_rag_data()

    # Sparse lexical retriever over the chunked corpus
    bm25 = retriver.create_bm25(docs_processed)

    # Dense vector index for semantic search
    KNOWLEDGE_VECTOR_DATABASE = retriver.create_vector_db(docs_processed)

    # Reranker loaded through RAGatouille
    RERANKER = RAGPretrainedModel.from_pretrained(config.CROSS_ENCODER_MODEL)

    # Assemble the full RAG answer generator
    rag_generator = rag.RAGAnswerGenerator(
        docs=docs_processed,
        bm25=bm25,
        knowledge_index=KNOWLEDGE_VECTOR_DATABASE,
        reranker=RERANKER
    )

    run_app(rag_generator)