Vasyl808 committed
Commit 7f0844d · 1 Parent(s): 5952931

Add application file

Files changed (6)
  1. app.py +138 -0
  2. chunker.py +45 -0
  3. config.py +28 -0
  4. rag.py +100 -0
  5. retriver.py +96 -0
  6. utils.py +12 -0
app.py ADDED
@@ -0,0 +1,138 @@
import gradio as gr
import utils
from datasets import load_dataset, concatenate_datasets
from langchain.docstore.document import Document as LangchainDocument
from tqdm import tqdm
import pickle
from ragatouille import RAGPretrainedModel
import chunker
import retriver
import rag
import nltk
import config
import os
import warnings
import sys
import logging

logging.getLogger("langchain").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")


class AnswerSystem:
    def __init__(self, rag_system) -> None:
        self.rag_system = rag_system

    def answer_generate(self, question, bm_25_flag, semantic_flag, temperature):
        answer, relevant_docs = self.rag_system.answer(
            question=question,
            temperature=temperature,
            bm_25_flag=bm_25_flag,
            semantic_flag=semantic_flag,
            num_retrieved_docs=10,
            num_docs_final=5
        )
        formatted_docs = "\n\n".join([f"Document {i + 1}: {doc}" for i, doc in enumerate(relevant_docs)])
        return answer, formatted_docs


def run_app(rag_model):
    with gr.Blocks() as demo:
        gr.Markdown(
            """
# RealTimeData Monthly Collection - BBC News Documentation Assistant

Welcome! This system is designed to help you explore and find insights from the RealTimeData Monthly Collection - BBC News dataset.
For example:

- *"What position does Josko Gvardiol play, and how much did Manchester City pay for him?"*
"""
        )

        # Input fields
        question_input = gr.Textbox(
            label="Enter your question:",
            placeholder="E.g., What position does Josko Gvardiol play, and how much did Manchester City pay for him?"
        )
        bm25_checkbox = gr.Checkbox(label="Enable BM25-based retrieval", value=True)  # BM25 flag
        semantic_checkbox = gr.Checkbox(label="Enable Semantic Search", value=True)  # Semantic flag
        temperature_slider = gr.Slider(label="Response Temperature", minimum=0.1, maximum=1.0, value=0.5, step=0.1)  # Temperature

        # Search button
        search_button = gr.Button("Search")

        # Output fields
        answer_output = gr.Textbox(label="Answer", interactive=False, lines=5)
        docs_output = gr.Textbox(label="Relevant Documents", interactive=False, lines=10)

        # Search logic
        system = AnswerSystem(rag_model)

        search_button.click(
            system.answer_generate,
            inputs=[question_input, bm25_checkbox, semantic_checkbox, temperature_slider],  # All parameters
            outputs=[answer_output, docs_output]
        )

    # Launch the app
    demo.launch(debug=True, share=True)


def get_rag_data():
    nltk.download('punkt')
    nltk.download('punkt_tab')

    if os.path.exists(config.DOCUMENTS_PATH):
        print(f"Loading preprocessed documents from {config.DOCUMENTS_PATH}")
        with open(config.DOCUMENTS_PATH, "rb") as file:
            docs_processed = pickle.load(file)
    else:
        print("Processing documents...")
        # Use a distinct loop variable so the imported `config` module is not shadowed.
        datasets_list = [
            utils.align_features(load_dataset("RealTimeData/bbc_news_alltime", dataset_config)["train"])
            for dataset_config in tqdm(config.AVAILABLE_DATASET_CONFIGS)
        ]

        ds = concatenate_datasets(datasets_list)

        RAW_KNOWLEDGE_BASE = [
            LangchainDocument(
                page_content=doc["content"],
                metadata={
                    "title": doc["title"],
                    "published_date": doc["published_date"],
                    "authors": doc["authors"],
                    "section": doc["section"],
                    "description": doc["description"],
                    "link": doc["link"]
                }
            )
            for doc in tqdm(ds)
        ]

        docs_processed = chunker.split_documents(512, RAW_KNOWLEDGE_BASE)

        print(f"Saving preprocessed documents to {config.DOCUMENTS_PATH}")
        with open(config.DOCUMENTS_PATH, "wb") as file:
            pickle.dump(docs_processed, file)

    return docs_processed


if __name__ == '__main__':
    docs_processed = get_rag_data()

    bm25 = retriver.create_bm25(docs_processed)

    KNOWLEDGE_VECTOR_DATABASE = retriver.create_vector_db(docs_processed)

    RERANKER = RAGPretrainedModel.from_pretrained(config.CROSS_ENCODER_MODEL)

    rag_generator = rag.RAGAnswerGenerator(
        docs=docs_processed,
        bm25=bm25,
        knowledge_index=KNOWLEDGE_VECTOR_DATABASE,
        reranker=RERANKER
    )

    run_app(rag_generator)
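
For quick checks without launching the Gradio UI, the same pipeline can be driven from a Python session; a minimal sketch, assuming the script above is importable and using an example question:

# Hypothetical smoke test: query the generator headlessly.
from app import get_rag_data
from ragatouille import RAGPretrainedModel
import rag, retriver, config

docs = get_rag_data()
generator = rag.RAGAnswerGenerator(
    docs=docs,
    bm25=retriver.create_bm25(docs),
    knowledge_index=retriver.create_vector_db(docs),
    reranker=RAGPretrainedModel.from_pretrained(config.CROSS_ENCODER_MODEL),
)
answer, sources = generator.answer(
    "What position does Josko Gvardiol play?",
    temperature=0.5,
    bm_25_flag=True,
    semantic_flag=True,
    num_retrieved_docs=10,
    num_docs_final=5,
)
print(answer)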
chunker.py ADDED
@@ -0,0 +1,45 @@
import config
from langchain.docstore.document import Document as LangchainDocument
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer
from tqdm import tqdm  # plain tqdm (not tqdm.notebook): this module runs from a script, not a notebook
from typing import List


def split_documents(chunk_size: int, knowledge_base: List[LangchainDocument]) -> List[LangchainDocument]:
    """
    Split documents into chunks of at most `chunk_size` tokens and return the list of chunks.
    """
    MARKDOWN_SEPARATORS = [
        "\n#{1,6} ",
        "```\n",
        "\n\\*\\*\\*+\n",
        "\n---+\n",
        "\n___+\n",
        "\n\n",
        "\n",
        " ",
        "",
    ]

    text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        AutoTokenizer.from_pretrained(config.EMBEDDING_MODEL_NAME),
        chunk_size=chunk_size,
        chunk_overlap=int(chunk_size / 10),
        add_start_index=True,
        strip_whitespace=True,
        separators=MARKDOWN_SEPARATORS,
    )

    docs_processed = []
    for doc in tqdm(knowledge_base):
        docs_processed += text_splitter.split_documents([doc])

    # Drop duplicate chunks (identical page_content) while preserving order.
    unique_texts = {}
    docs_processed_unique = []
    for doc in docs_processed:
        if doc.page_content not in unique_texts:
            unique_texts[doc.page_content] = True
            docs_processed_unique.append(doc)

    return docs_processed_unique
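
To illustrate the chunker in isolation, a minimal sketch with made-up document contents:

# Illustrative only: split two in-memory documents into chunks of at most 512 tokens.
from langchain.docstore.document import Document as LangchainDocument
import chunker

sample_docs = [
    LangchainDocument(page_content="First article text ...", metadata={"title": "A"}),
    LangchainDocument(page_content="Second article text ...", metadata={"title": "B"}),
]
chunks = chunker.split_documents(512, sample_docs)
print(len(chunks), chunks[0].metadata)  # chunks keep source metadata plus a start_index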
config.py ADDED
@@ -0,0 +1,28 @@
import os


RAG_PROMPT = """
You are an advanced Retrieval-Augmented Generation (RAG) Assistant.

Your task is to answer user questions based only on the provided documents. Use the context from the documents to generate a response.

**Guidelines:**
1. **Always cite sources**: When information is derived from a document, reference it by citing the chunk number in square brackets, e.g., [Chunk 1], wherever relevant information is used.
2. If the answer cannot be determined from the provided documents, state: "The answer cannot be determined from the provided documents."
3. After each answer, provide a numbered list of the retrieved chunks.

Please follow these instructions to generate accurate and well-cited answers based on the documents.
"""
LLM_ONLY_PROMPT = """You are an Assistant. If no documents are retrieved, answer the question based on general knowledge."""

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# Provide the Groq API key through the environment (e.g. a Space secret); do not commit real keys to source control.
os.environ.setdefault('GROQ_API_KEY', "")
DB_PATH = "vector_database.faiss"
BM25_PATH = "bm25_index.pkl"
DOCUMENTS_PATH = "processed_documents.pkl"
EMBEDDING_MODEL_NAME = "thenlper/gte-small"
CROSS_ENCODER_MODEL = "colbert-ir/colbertv2.0"

AVAILABLE_DATASET_CONFIGS = [
    '2024-11'
]
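
With the key read from the environment, a minimal sketch of supplying it before the config module is imported (the value shown is a placeholder, not a real key):

import os
os.environ["GROQ_API_KEY"] = "<your-groq-api-key>"  # placeholder; load from a secrets store in practice

import config  # setdefault in config.py then leaves the value you provided untouched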
rag.py ADDED
@@ -0,0 +1,100 @@
from typing import Optional, List, Tuple
from langchain.docstore.document import Document as LangchainDocument
from rank_bm25 import BM25Okapi
from langchain_community.vectorstores import FAISS
from ragatouille import RAGPretrainedModel
from litellm import completion
import os
import retriver
import config


class RAGAnswerGenerator:
    def __init__(self, docs: List[LangchainDocument], bm25: BM25Okapi, knowledge_index: FAISS, reranker: Optional[RAGPretrainedModel] = None):
        self.bm25 = bm25
        self.knowledge_index = knowledge_index
        self.docs = docs
        self.reranker = reranker
        self.llm_key = os.environ['GROQ_API_KEY']

    def retrieve_documents(
        self,
        question: str,
        num_retrieved_docs: int,
        bm_25_flag: bool,
        semantic_flag: bool
    ) -> List[str]:
        print("=> Retrieving documents...")
        relevant_docs = []

        if bm_25_flag or semantic_flag:
            result = retriver.search(
                self.docs,
                self.bm25,
                self.knowledge_index,
                question,
                use_bm25=bm_25_flag,
                use_semantic_search=semantic_flag,
                top_k=num_retrieved_docs
            )
            if bm_25_flag and not semantic_flag:
                # BM25-only search already returns plain strings.
                relevant_docs = result
            else:
                relevant_docs = [doc.page_content for doc in result]
        # With both flags disabled an empty list is returned, so the caller falls back to the LLM-only prompt.
        return relevant_docs

    def rerank_documents(self, question: str, documents: List[str], num_docs_final: int) -> List[str]:
        if self.reranker and documents:
            print("=> Reranking documents...")
            reranked_docs = self.reranker.rerank(question, documents, k=num_docs_final)
            return [doc["content"] for doc in reranked_docs]
        return documents[:num_docs_final]

    def format_context(self, documents: List[str]) -> str:
        if not documents:
            return "No retrieved documents available."
        return "\n".join([f"[{i + 1}] {doc}" for i, doc in enumerate(documents)])

    def generate_answer(
        self,
        question: str,
        context: str,
        temperature: float,
    ) -> str:
        print("=> Generating answer...")
        if context.strip() == "No retrieved documents available.":
            response = completion(
                model="groq/llama3-8b-8192",
                messages=[
                    {"role": "system", "content": config.LLM_ONLY_PROMPT},
                    {"role": "user", "content": f"Question: {question}"}
                ],
                api_key=self.llm_key,
                temperature=temperature
            )
        else:
            response = completion(
                model="groq/llama3-8b-8192",
                messages=[
                    {"role": "system", "content": config.RAG_PROMPT},
                    {"role": "user", "content": f"Context: {context}\nQuestion: {question}"}
                ],
                api_key=self.llm_key,
                temperature=temperature
            )
        return response.get("choices", [{}])[0].get("message", {}).get("content", "No response content found")

    def answer(self, question: str, temperature: float, num_retrieved_docs: int = 30, num_docs_final: int = 5, bm_25_flag=True, semantic_flag=True) -> Tuple[str, List[str]]:
        relevant_docs = self.retrieve_documents(question, num_retrieved_docs, bm_25_flag, semantic_flag)
        print(f"Retrieved {len(relevant_docs)} documents")
        relevant_docs = self.rerank_documents(question, relevant_docs, num_docs_final)
        print(f"Kept {len(relevant_docs)} documents after reranking")
        context = self.format_context(relevant_docs)
        answer = self.generate_answer(question, context, temperature)
        document_list = [f"[{i + 1}] {doc}" for i, doc in enumerate(relevant_docs)] if relevant_docs else []
        return answer, document_list
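
The retrieval flags change both the search path and the shape of what comes back; a hedged sketch of the BM25-only path with no reranker, assuming docs_processed, bm25 and KNOWLEDGE_VECTOR_DATABASE from app.py are in scope:

# Illustrative only: retriver.search returns plain strings in BM25-only mode,
# and rerank_documents simply truncates them when no reranker is supplied.
from rag import RAGAnswerGenerator

generator = RAGAnswerGenerator(docs=docs_processed, bm25=bm25,
                               knowledge_index=KNOWLEDGE_VECTOR_DATABASE, reranker=None)
answer, shown = generator.answer("Who signed Josko Gvardiol?", temperature=0.3,
                                 bm_25_flag=True, semantic_flag=False)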
retriver.py ADDED
@@ -0,0 +1,96 @@
import config
import utils
from nltk.tokenize import word_tokenize
from typing import List
import nltk
import torch
import pickle
from langchain.docstore.document import Document as LangchainDocument
from rank_bm25 import BM25Okapi
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
import os


def create_vector_db(docs: List[LangchainDocument]):
    db_path: str = config.DB_PATH
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    embedding_model = HuggingFaceEmbeddings(
        model_name=config.EMBEDDING_MODEL_NAME,
        multi_process=True,
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": True},
    )

    if os.path.exists(db_path):
        print(f"Loading vector database from {db_path}")
        knowledge_vector_database = FAISS.load_local(
            db_path,
            embedding_model,
            allow_dangerous_deserialization=True
        )
        return knowledge_vector_database
    elif docs is not None:
        print("Creating a new vector database")
        knowledge_vector_database = FAISS.from_documents(
            docs, embedding_model, distance_strategy=DistanceStrategy.COSINE
        )
        knowledge_vector_database.save_local(db_path)
        print(f"Vector database saved to {db_path}")
        return knowledge_vector_database
    else:
        raise ValueError(
            """Documents are missing!
            Please process the documents first (see get_rag_data in app.py)."""
        )


def create_bm25(docs: List[LangchainDocument]):
    bm25_path: str = config.BM25_PATH
    if os.path.exists(bm25_path):
        print(f"Loading BM25 index from {bm25_path}")
        with open(bm25_path, "rb") as file:
            bm25 = pickle.load(file)
        return bm25
    elif docs is not None:
        print("Creating a new BM25 index")
        tokenized_docs = [word_tokenize(doc.page_content.lower()) for doc in docs]
        bm25 = BM25Okapi(tokenized_docs)
        with open(bm25_path, "wb") as file:
            pickle.dump(bm25, file)
        print(f"BM25 index saved to {bm25_path}")
        return bm25
    else:
        raise ValueError(
            """Documents are missing!
            Please process the documents first (see get_rag_data in app.py)."""
        )


def search(docs_processed, bm_25: BM25Okapi, vector_db: FAISS, query, top_k, use_bm25=True, use_semantic_search=True):
    if use_bm25 and use_semantic_search:
        # Combine lexical (BM25) and semantic (FAISS) retrieval with equal weights.
        bm25_retriever = BM25Retriever.from_documents(docs_processed)
        bm25_retriever.k = top_k
        faiss_retriever = vector_db.as_retriever(search_kwargs={"k": top_k})
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever],
            weights=[0.5, 0.5]
        )
        result = ensemble_retriever.invoke(query)
    elif use_bm25:
        # BM25-only: returns raw text strings rather than Document objects.
        tokenized_query = word_tokenize(query.lower())
        result = bm_25.get_top_n(tokenized_query, [doc.page_content for doc in docs_processed], n=top_k)
    elif use_semantic_search:
        result = vector_db.similarity_search(query, k=top_k)
    else:
        result = []
    return result
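
A short sketch of the three retrieval modes and their differing return types, assuming docs_processed, bm25 and KNOWLEDGE_VECTOR_DATABASE from app.py are in scope (the query string is only an example):

import retriver

hybrid = retriver.search(docs_processed, bm25, KNOWLEDGE_VECTOR_DATABASE,
                         "Gvardiol transfer fee", top_k=5)                              # list of Documents
lexical = retriver.search(docs_processed, bm25, KNOWLEDGE_VECTOR_DATABASE,
                          "Gvardiol transfer fee", top_k=5, use_semantic_search=False)  # list of strings
semantic = retriver.search(docs_processed, bm25, KNOWLEDGE_VECTOR_DATABASE,
                           "Gvardiol transfer fee", top_k=5, use_bm25=False)            # list of Documents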
utils.py ADDED
@@ -0,0 +1,12 @@
def align_features(dataset):
    """Normalise the "authors" feature so every example holds a list (possibly empty)."""
    def fix_authors(example):
        if not isinstance(example["authors"], list):
            return {"authors": [example["authors"]] if example["authors"] else []}
        return example

    return dataset.map(fix_authors)
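
For illustration, how align_features normalises a toy in-memory dataset (hypothetical values):

# Illustrative only: scalar or missing "authors" values become lists so monthly configs can be concatenated.
from datasets import Dataset
import utils

toy = Dataset.from_dict({"authors": ["Jane Doe", None], "content": ["a", "b"]})
fixed = utils.align_features(toy)
print(fixed["authors"])  # expected: [['Jane Doe'], []]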