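"""Hybrid retrieval utilities: BM25 and FAISS candidate generation over
pre-chunked, locally stored datasets, followed by cross-encoder re-ranking."""
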
import json
import numpy as np
from langchain.schema import Document
import faiss
from rank_bm25 import BM25Okapi
from data_processing import embedding_model
from sentence_transformers import CrossEncoder

# Optional punctuation-stripping preprocessing (currently disabled); re-enable
# these imports/downloads together with preprocess() below if needed.
# import string
# import nltk
# nltk.download('punkt')
# nltk.download('punkt_tab')
# from nltk.tokenize import word_tokenize

# Cross-encoder used to re-rank the merged BM25 + FAISS candidates
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# Tokenize the documents and remove punctuation (disabled; see imports above)
# def preprocess(doc):
#     return [word.lower() for word in word_tokenize(doc) if word not in string.punctuation]

def retrieve_documents_hybrid(query, q_dataset, top_k=5):
    with open(f"data_local/{q_dataset}_chunked_docs.json", "r") as f:
        chunked_documents = json.load(f)  # All chunked documents for this dataset

    faiss_index_path = f"data_local/{q_dataset}_quantized.faiss"
    index = faiss.read_index(faiss_index_path)

    # Tokenize documents for BM25
    tokenized_docs = [doc.split() for doc in chunked_documents]
    bm25 = BM25Okapi(tokenized_docs)

    # Embed the query for FAISS; the index expects shape (1, dim)
    query_embedding = np.array(embedding_model.embed_documents([query]), dtype=np.float32)
    query_embedding = query_embedding.reshape(1, -1)

    # FAISS search
    _, faiss_indices = index.search(query_embedding, top_k)
    faiss_docs = [chunked_documents[i] for i in faiss_indices[0]]

    # BM25 search
    tokenized_query = query.split()  # preprocess(query)
    bm25_scores = bm25.get_scores(tokenized_query)
    bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
    bm25_docs = [chunked_documents[i] for i in bm25_top_indices]

    # Alternative: fuse raw BM25 + FAISS scores via rerank_docs_bm25faiss_scores
    # combined_results = set(bm25_top_indices).union(set(faiss_indices[0]))
    # combined_scores = rerank_docs_bm25faiss_scores(combined_results, bm25_scores, faiss_distances, faiss_indices)
    # reranked_docs = [chunked_documents[result[0]] for result in combined_scores[:top_k]]

    # Merge FAISS + BM25 candidates (deduplicate while preserving order) and re-rank
    retrieved_docs = list(dict.fromkeys(faiss_docs + bm25_docs))
    reranked_docs = rerank_documents(query, retrieved_docs, top_k=top_k)

    return reranked_docs

def rerank_docs_bm25faiss_scores(combined_results_, bm25_scores_, faiss_distances_, faiss_indices_):
    final_results = []
    for idx in combined_results_:
        # Combine BM25 score and inverse FAISS distance (simple additive fusion)
        bm25_score = bm25_scores_[idx]
        faiss_positions = np.where(faiss_indices_[0] == idx)[0]
        if faiss_positions.size > 0:
            faiss_score = 1 / (1 + faiss_distances_[0][faiss_positions[0]])  # Inverse distance for relevance
        else:
            faiss_score = 0.0  # Document was retrieved by BM25 only
        final_results.append((idx, bm25_score, faiss_score))

    # Sort by combined score (the fusion strategy can be adjusted here)
    final_results.sort(key=lambda x: (x[1] + x[2]), reverse=True)

    return final_results

# Retrieval Function
# def retrieve_documents(query, top_k=5):
#     query_dataset = find_query_dataset(query)
#     #index, chunk_docs = load_data_from_faiss(query)

#     with open( f"data_local/{query_dataset}_chunked_docs.json", "r") as f:
#         documents = json.load(f)  # Contains all documents for this dataset

#     faiss_index_path = f"data_local/{query_dataset}_quantized.faiss"
#     index = faiss.read_index(faiss_index_path)

#     query_embedding = np.array(embedding_model.embed_documents([query]), dtype=np.float32)

#     _, nearest_indices = index.search(query_embedding, top_k)

#     retrieved_docs = [Document(page_content=documents[i]) for i in nearest_indices[0]]

#     return retrieved_docs

def remove_duplicate_documents(documents):
    # Keep only the first occurrence of each unique page_content
    unique_documents = []
    seen_documents = set()
    for doc in documents:
        doc_content = doc.page_content
        if doc_content not in seen_documents:
            unique_documents.append(doc)
            seen_documents.add(doc_content)
    return unique_documents

def find_query_dataset(query):
    # Route the query to the most similar dataset via a FAISS index built over representative questions
    index = faiss.read_index("data_local/question_quantized.faiss")

    with open("data_local/dataset_mapping.json", "r") as f:
        dataset_names = json.load(f)

    question_embedding = np.array(embedding_model.embed_documents([query]), dtype=np.float32)
    _, nearest_index = index.search(question_embedding, 1)  
    best_dataset = dataset_names[nearest_index[0][0]]
    return best_dataset

def rerank_documents(query, retrieved_docs, top_k=5):
    # Score each (query, document) pair with the cross-encoder and sort by descending score
    scores = reranker.predict([[query, doc] for doc in retrieved_docs])
    ranked_docs = [doc for _, doc in sorted(zip(scores, retrieved_docs), key=lambda pair: pair[0], reverse=True)]
    return ranked_docs[:top_k]  # Return the top_k most relevant documents
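
# Example usage (illustrative sketch): assumes the data_local/ artifacts
# (dataset_mapping.json, <dataset>_chunked_docs.json, <dataset>_quantized.faiss)
# have already been built by the indexing pipeline; the query string is hypothetical.
if __name__ == "__main__":
    example_query = "What are the side effects of aspirin?"
    dataset = find_query_dataset(example_query)
    top_docs = retrieve_documents_hybrid(example_query, dataset, top_k=5)
    for rank, doc in enumerate(top_docs, start=1):
        print(f"{rank}. {doc[:200]}")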