Vasyl808 committed · Commit 7f0844d · Parent(s): 5952931

Add application file

Files changed:
- app.py +138 -0
- chunker.py +45 -0
- config.py +28 -0
- rag.py +100 -0
- retriver.py +96 -0
- utils.py +12 -0
app.py
ADDED
@@ -0,0 +1,138 @@
import gradio as gr
import utils
from datasets import load_dataset, concatenate_datasets
from langchain.docstore.document import Document as LangchainDocument
from tqdm import tqdm
import pickle
from ragatouille import RAGPretrainedModel
import chunker
import retriver
import rag
import nltk
import config
import os
import warnings
import logging

logging.getLogger("langchain").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")


class AnswerSystem:
    def __init__(self, rag_system) -> None:
        self.rag_system = rag_system

    def answer_generate(self, question, bm_25_flag, semantic_flag, temperature):
        answer, relevant_docs = self.rag_system.answer(
            question=question,
            temperature=temperature,
            bm_25_flag=bm_25_flag,
            semantic_flag=semantic_flag,
            num_retrieved_docs=10,
            num_docs_final=5
        )
        formatted_docs = "\n\n".join([f"Document {i + 1}: {doc}" for i, doc in enumerate(relevant_docs)])
        return answer, formatted_docs


def run_app(rag_model):
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # RealTimeData Monthly Collection - BBC News Documentation Assistant

            Welcome! This system is designed to help you explore and find insights from the RealTimeData Monthly Collection - BBC News dataset.
            For example:

            - *"What position does Josko Gvardiol play, and how much did Manchester City pay for him?"*
            """
        )

        # Input fields
        question_input = gr.Textbox(label="Enter your question:",
                                    placeholder="E.g., What position does Josko Gvardiol play, and how much did Manchester City pay for him?")
        bm25_checkbox = gr.Checkbox(label="Enable BM25-based retrieval", value=True)  # BM25 flag
        semantic_checkbox = gr.Checkbox(label="Enable Semantic Search", value=True)  # Semantic flag
        temperature_slider = gr.Slider(label="Response Temperature", minimum=0.1, maximum=1.0, value=0.5,
                                       step=0.1)  # Temperature

        # Search button
        search_button = gr.Button("Search")

        # Output fields
        answer_output = gr.Textbox(label="Answer", interactive=False, lines=5)
        docs_output = gr.Textbox(label="Relevant Documents", interactive=False, lines=10)

        # Search logic
        system = AnswerSystem(rag_model)

        search_button.click(
            system.answer_generate,
            inputs=[question_input, bm25_checkbox, semantic_checkbox, temperature_slider],  # All parameters
            outputs=[answer_output, docs_output]
        )

    # Launch the app
    demo.launch(debug=True, share=True)


def get_rag_data():
    nltk.download('punkt')
    nltk.download('punkt_tab')

    if os.path.exists(config.DOCUMENTS_PATH):
        print(f"Loading preprocessed documents from {config.DOCUMENTS_PATH}")
        with open(config.DOCUMENTS_PATH, "rb") as file:
            docs_processed = pickle.load(file)
    else:
        print("Processing documents...")
        # Use a distinct loop variable so the `config` module is not shadowed
        # inside the comprehension.
        datasets_list = [
            utils.align_features(load_dataset("RealTimeData/bbc_news_alltime", ds_config)["train"])
            for ds_config in tqdm(config.AVAILABLE_DATASET_CONFIGS)
        ]

        ds = concatenate_datasets(datasets_list)

        RAW_KNOWLEDGE_BASE = [
            LangchainDocument(
                page_content=doc["content"],
                metadata={
                    "title": doc["title"],
                    "published_date": doc["published_date"],
                    "authors": doc["authors"],
                    "section": doc["section"],
                    "description": doc["description"],
                    "link": doc["link"]
                }
            )
            for doc in tqdm(ds)
        ]

        docs_processed = chunker.split_documents(512, RAW_KNOWLEDGE_BASE)

        print(f"Saving preprocessed documents to {config.DOCUMENTS_PATH}")
        with open(config.DOCUMENTS_PATH, "wb") as file:
            pickle.dump(docs_processed, file)

    return docs_processed


if __name__ == '__main__':
    docs_processed = get_rag_data()

    bm25 = retriver.create_bm25(docs_processed)

    KNOWLEDGE_VECTOR_DATABASE = retriver.create_vector_db(docs_processed)

    RERANKER = RAGPretrainedModel.from_pretrained(config.CROSS_ENCODER_MODEL)

    rag_generator = rag.RAGAnswerGenerator(
        docs=docs_processed,
        bm25=bm25,
        knowledge_index=KNOWLEDGE_VECTOR_DATABASE,
        reranker=RERANKER
    )

    run_app(rag_generator)
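The imports across these files imply the Space's dependency set. A requirements.txt sketch covering them (package names are inferred from the imports, versions are left unpinned, and faiss-cpu is an assumption for FAISS support; none of this list is part of the commit itself):

gradio
datasets
langchain
langchain-community
tqdm
ragatouille
nltk
transformers
rank_bm25
litellm
torch
faiss-cpu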
chunker.py
ADDED
@@ -0,0 +1,45 @@
import config
from langchain.docstore.document import Document as LangchainDocument
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import AutoTokenizer
from tqdm import tqdm  # plain tqdm; tqdm.notebook only works inside Jupyter
from typing import List


def split_documents(chunk_size: int, knowledge_base: List[LangchainDocument]) -> List[LangchainDocument]:
    """
    Split documents into chunks of maximum size `chunk_size` tokens and return a list of documents.
    """
    MARKDOWN_SEPARATORS = [
        "\n#{1,6} ",
        "```\n",
        "\n\\*\\*\\*+\n",
        "\n---+\n",
        "\n___+\n",
        "\n\n",
        "\n",
        " ",
        "",
    ]

    text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        AutoTokenizer.from_pretrained(config.EMBEDDING_MODEL_NAME),
        chunk_size=chunk_size,
        chunk_overlap=int(chunk_size / 10),
        add_start_index=True,
        strip_whitespace=True,
        separators=MARKDOWN_SEPARATORS,
    )

    docs_processed = []
    for doc in tqdm(knowledge_base):
        docs_processed += text_splitter.split_documents([doc])

    # Drop duplicate chunks while preserving order.
    unique_texts = {}
    docs_processed_unique = []
    for doc in docs_processed:
        if doc.page_content not in unique_texts:
            unique_texts[doc.page_content] = True
            docs_processed_unique.append(doc)

    return docs_processed_unique
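A minimal usage sketch for split_documents; the sample document below is illustrative, not from the dataset:

from langchain.docstore.document import Document as LangchainDocument
import chunker

# One short made-up article wrapped as a LangchainDocument.
sample = [LangchainDocument(page_content="Josko Gvardiol joined Manchester City from RB Leipzig.",
                            metadata={"title": "Transfer news"})]
chunks = chunker.split_documents(512, sample)
# Each chunk keeps its metadata, plus a start_index added by add_start_index=True.
print(len(chunks), chunks[0].metadata)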
config.py
ADDED
@@ -0,0 +1,28 @@
import os


RAG_PROMPT = """
You are an advanced Retrieval-Augmented Generation (RAG) Assistant.

Your task is to answer user questions based only on the provided documents. Use the context from the documents to generate a response.

**Guidelines:**
1. **Always cite sources**: When information is derived from a document, reference it by citing the chunk number in square brackets, e.g., [Chunk 1], wherever relevant information is used.
2. If the answer cannot be determined from the provided documents, state: "The answer cannot be determined from the provided documents."
3. After each answer, provide a numbered list of the retrieved chunks.

Please follow these instructions to generate accurate and well-cited answers based on the documents.
"""
LLM_ONLY_PROMPT = """You are an Assistant. If no documents are retrieved, answer the question based on general knowledge."""

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
# Set GROQ_API_KEY in the environment (e.g., as a Space secret);
# never hardcode API keys in source control.
DB_PATH = "vector_database.faiss"
BM25_PATH = "bm25_index.pkl"
DOCUMENTS_PATH = "processed_documents.pkl"
EMBEDDING_MODEL_NAME = "thenlper/gte-small"
CROSS_ENCODER_MODEL = "colbert-ir/colbertv2.0"

AVAILABLE_DATASET_CONFIGS = [
    '2024-11'
]
rag.py
ADDED
@@ -0,0 +1,100 @@
from typing import Optional, List, Tuple
from langchain.docstore.document import Document as LangchainDocument
from rank_bm25 import BM25Okapi
from langchain_community.vectorstores import FAISS
from ragatouille import RAGPretrainedModel
from litellm import completion
import os
import retriver
import config


class RAGAnswerGenerator:
    def __init__(self, docs: List[LangchainDocument], bm25: BM25Okapi, knowledge_index: FAISS, reranker: Optional[RAGPretrainedModel] = None):
        self.bm25 = bm25
        self.knowledge_index = knowledge_index
        self.docs = docs
        self.reranker = reranker
        self.llm_key = os.environ["GROQ_API_KEY"]  # must be set in the environment

    def retrieve_documents(
        self,
        question: str,
        num_retrieved_docs: int,
        bm_25_flag: bool,
        semantic_flag: bool
    ) -> List[str]:
        print("=> Retrieving documents...")
        relevant_docs = []

        if bm_25_flag or semantic_flag:
            result = retriver.search(
                self.docs,
                self.bm25,
                self.knowledge_index,
                question,
                use_bm25=bm_25_flag,
                use_semantic_search=semantic_flag,
                top_k=num_retrieved_docs
            )
            if bm_25_flag and not semantic_flag:
                # BM25-only search already returns plain strings.
                relevant_docs = result
            else:
                relevant_docs = [doc.page_content for doc in result]
        return relevant_docs

    def rerank_documents(self, question: str, documents: List[str], num_docs_final: int) -> List[str]:
        if self.reranker and documents:
            print("=> Reranking documents...")
            reranked_docs = self.reranker.rerank(question, documents, k=num_docs_final)
            return [doc["content"] for doc in reranked_docs]
        return documents[:num_docs_final]

    def format_context(self, documents: List[str]) -> str:
        if not documents:
            return "No retrieved documents available."
        return "\n".join([f"[{i + 1}] {doc}" for i, doc in enumerate(documents)])

    def generate_answer(
        self,
        question: str,
        context: str,
        temperature: float,
    ) -> str:
        print("=> Generating answer...")
        if context.strip() == "No retrieved documents available.":
            response = completion(
                model="groq/llama3-8b-8192",
                messages=[
                    {"role": "system", "content": config.LLM_ONLY_PROMPT},
                    {"role": "user", "content": f"Question: {question}"}
                ],
                api_key=self.llm_key,
                temperature=temperature
            )
        else:
            response = completion(
                model="groq/llama3-8b-8192",
                messages=[
                    {"role": "system", "content": config.RAG_PROMPT},
                    {"role": "user", "content": f"Context: {context}\nQuestion: {question}"}
                ],
                api_key=self.llm_key,
                temperature=temperature
            )
        return response.get("choices", [{}])[0].get("message", {}).get("content", "No response content found")

    def answer(self, question: str, temperature: float, num_retrieved_docs: int = 30, num_docs_final: int = 5, bm_25_flag=True, semantic_flag=True) -> Tuple[str, List[str]]:
        relevant_docs = self.retrieve_documents(question, num_retrieved_docs, bm_25_flag, semantic_flag)
        relevant_docs = self.rerank_documents(question, relevant_docs, num_docs_final)
        context = self.format_context(relevant_docs)
        answer = self.generate_answer(question, context, temperature)
        document_list = [f"[{i + 1}] {doc}" for i, doc in enumerate(relevant_docs)] if relevant_docs else []
        return answer, document_list
retriver.py
ADDED
@@ -0,0 +1,96 @@
import config
from nltk.tokenize import word_tokenize
from typing import List
import torch
import pickle
from langchain.docstore.document import Document as LangchainDocument
from rank_bm25 import BM25Okapi
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
import os


def create_vector_db(docs: List[LangchainDocument]):
    db_path: str = config.DB_PATH
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    embedding_model = HuggingFaceEmbeddings(
        model_name=config.EMBEDDING_MODEL_NAME,
        multi_process=True,
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": True},
    )

    if os.path.exists(db_path):
        print(f"Loading vector database from {db_path}")
        knowledge_vector_database = FAISS.load_local(
            db_path,
            embedding_model,
            allow_dangerous_deserialization=True
        )
        return knowledge_vector_database
    elif docs is not None:
        print("Creating a new vector database")
        knowledge_vector_database = FAISS.from_documents(
            docs, embedding_model, distance_strategy=DistanceStrategy.COSINE
        )
        knowledge_vector_database.save_local(db_path)
        print(f"Vector database saved to {db_path}")
        return knowledge_vector_database
    else:
        raise ValueError(
            """Documents are missing!
            Please load the documents and set get_data=True in app.py."""
        )


def create_bm25(docs: List[LangchainDocument]):
    bm25_path: str = config.BM25_PATH
    if os.path.exists(bm25_path):
        print(f"Loading BM25 index from {bm25_path}")
        with open(bm25_path, "rb") as file:
            bm25 = pickle.load(file)
        return bm25
    elif docs is not None:
        print("Creating a new BM25 index")
        tokenized_docs = [word_tokenize(doc.page_content.lower()) for doc in docs]
        bm25 = BM25Okapi(tokenized_docs)
        with open(bm25_path, "wb") as file:
            pickle.dump(bm25, file)
        print(f"BM25 index saved to {bm25_path}")
        return bm25
    else:
        raise ValueError(
            """Documents are missing!
            Please load the documents and set get_data=True in app.py."""
        )


def search(docs_processed, bm_25: BM25Okapi, vector_db: FAISS, query, top_k, use_bm25=True, use_semantic_search=True):
    if use_bm25 and use_semantic_search:
        # Hybrid retrieval: combine BM25 and FAISS results with equal weights.
        bm25_retriever = BM25Retriever.from_documents(docs_processed)
        bm25_retriever.k = top_k
        faiss_retriever = vector_db.as_retriever(search_kwargs={"k": top_k})
        ensemble_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, faiss_retriever],
            weights=[0.5, 0.5]
        )
        result = ensemble_retriever.invoke(query)
    elif use_bm25:
        tokenized_query = word_tokenize(query.lower())
        result = bm_25.get_top_n(tokenized_query, [doc.page_content for doc in docs_processed], n=top_k)
    elif use_semantic_search:
        result = vector_db.similarity_search(query, k=top_k)
    else:
        result = []
    return result
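A usage sketch for the three retrieval modes. The two-document corpus is made up, and create_bm25/create_vector_db cache to the paths in config.py; note that the BM25-only path returns plain strings while the other paths return Document objects, an asymmetry rag.py compensates for:

from langchain.docstore.document import Document as LangchainDocument
import retriver

# Tiny illustrative corpus; real chunks come from chunker.split_documents.
docs = [LangchainDocument(page_content="Manchester City paid 77m pounds for Gvardiol."),
        LangchainDocument(page_content="Gvardiol plays as a centre-back.")]
bm25 = retriver.create_bm25(docs)
db = retriver.create_vector_db(docs)

hybrid = retriver.search(docs, bm25, db, "Gvardiol fee", top_k=2)                                # Documents
bm25_only = retriver.search(docs, bm25, db, "Gvardiol fee", top_k=2, use_semantic_search=False)  # strings
semantic = retriver.search(docs, bm25, db, "Gvardiol fee", top_k=2, use_bm25=False)              # Documents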
utils.py
ADDED
@@ -0,0 +1,12 @@
def align_features(dataset):
    """Normalize the `authors` feature so it is always a list."""
    def fix_authors(example):
        if not isinstance(example["authors"], list):
            return {"authors": [example["authors"]] if example["authors"] else []}
        return example

    return dataset.map(fix_authors)
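For illustration, align_features turns a scalar or missing authors value into a list so monthly configs can be concatenated; the rows below are made up:

from datasets import Dataset
import utils

# A config where `authors` is a plain (nullable) string column.
ds = Dataset.from_list([{"authors": "Jane Doe"}, {"authors": None}])
print(utils.align_features(ds)["authors"])  # [['Jane Doe'], []]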