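"""Multi-vector-store RAG retrieval tester.

Retrieves the top-k chunks for a query from persisted vector stores using a
chosen embedding model, deduplicates them, and prints previews. Run as a
script for an interactive prompt that can target individual stores or compare
all of them.
"""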

from utils.vector_store import get_vector_store
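
# ---------------------------------------------------------------------------
# The EMBEDDING_CONFIGS dict imported in __main__ is expected to live in
# embedding_config.py. A minimal sketch of that module, based on the
# HuggingFace models used in earlier revisions of this file (the exact model
# choices and kwargs are assumptions):
#
#     from langchain_community.embeddings import HuggingFaceEmbeddings
#
#     EMBEDDING_CONFIGS = {
#         "General-purpose (bge-large-en)": HuggingFaceEmbeddings(
#             model_name="BAAI/bge-large-en",
#             model_kwargs={"device": "cpu"},  # or "cuda" with a GPU
#             encode_kwargs={"normalize_embeddings": True},
#         ),
#         "QA optimized (e5-large-v2)": HuggingFaceEmbeddings(
#             model_name="intfloat/e5-large-v2",
#             model_kwargs={"device": "cpu"},
#             encode_kwargs={"normalize_embeddings": True},
#         ),
#     }
# ---------------------------------------------------------------------------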


def test_retriever_with_embeddings(query: str, embedding_model, k: int = 3, vector_store_path="./chroma_db"):
    """Test retriever with a specific embedding model and vector store"""
    vector_store = get_vector_store(
        persist_directory=vector_store_path, embedding=embedding_model)
    retriever = vector_store.as_retriever(search_kwargs={"k": k})
    docs = retriever.invoke(query)  # .invoke() supersedes the deprecated get_relevant_documents()

    # Deduplicate based on page_content
    seen = set()
    unique_docs = []
    for doc in docs:
        if doc.page_content not in seen:
            seen.add(doc.page_content)
            unique_docs.append(doc)

    print(f"\nUsing vector store: {vector_store_path}")
    print(f"Top {len(unique_docs)} unique chunks retrieved for: '{query}'\n")

    for i, doc in enumerate(unique_docs, 1):
        source = doc.metadata.get("source", "unknown")
        page = doc.metadata.get("page", "N/A")
        print(f"--- Chunk #{i} ---")
        print(f"Source: {source} | Page: {page}")
        preview = doc.page_content[:300]
        if len(doc.page_content) > 300:
            preview += "..."
        print(preview)
        print()
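
# Example usage (assumes a populated Chroma store at ./chroma_db and an
# embedding_config module like the sketch above; the query is illustrative):
#
#     from embedding_config import EMBEDDING_CONFIGS
#     model = next(iter(EMBEDDING_CONFIGS.values()))
#     test_retriever_with_embeddings("How do I release a work order?", model, k=3)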


def compare_retrievers_with_embeddings(query: str, embedding_model, k: int = 3):
    """Compare results from different vector stores using the same embedding model"""
    stores = {
        "MES Manual": "./vector_stores/mes_db",
        "Technical Docs": "./vector_stores/tech_db",
        "General Docs": "./vector_stores/general_db"
    }

    print(f"\n=== Comparing retrievers for: '{query}' ===\n")

    for store_name, store_path in stores.items():
        try:
            print(f"πŸ” {store_name}:")
            print("-" * 50)
            test_retriever_with_embeddings(
                query, embedding_model, k=k, vector_store_path=store_path)
            print("\n" + "="*60 + "\n")
        except Exception as e:
            print(f"❌ Could not access {store_name}: {e}\n")


if __name__ == "__main__":
    from embedding_config import EMBEDDING_CONFIGS
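    # Imported here (not at module top) so the embedding models are only instantiated when the script actually runs.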

    print("Multi-Vector Store RAG Tester (with Embeddings)")
    print("\nAvailable commands:")
    print("  - Enter a question to test default store")
    print("  - Type 'mes: <question>' for MES manual")
    print("  - Type 'tech: <question>' for technical docs")
    print("  - Type 'general: <question>' for general docs")
    print("  - Type 'compare: <question>' to compare all stores")
    print("  - Type 'exit' to quit")

    # Choose embedding model at start (validated so bad input doesn't crash the prompt)
    print("\nAvailable Embedding Models:")
    model_names = list(EMBEDDING_CONFIGS.keys())
    for i, name in enumerate(model_names, 1):
        print(f"  {i}. {name}")
    while True:
        raw_choice = input("Select embedding model number: ").strip()
        if raw_choice.isdigit() and 1 <= int(raw_choice) <= len(model_names):
            break
        print(f"Please enter a number between 1 and {len(model_names)}.")
    embedding_model = EMBEDDING_CONFIGS[model_names[int(raw_choice) - 1]]

    while True:
        user_input = input("\nEnter your question: ").strip()

        if user_input.lower() == "exit":
            break
        elif user_input.lower().startswith("mes: "):
            query = user_input[5:]
            test_retriever_with_embeddings(
                query, embedding_model, vector_store_path="./vector_stores/mes_db")
        elif user_input.lower().startswith("tech: "):
            query = user_input[6:]
            test_retriever_with_embeddings(
                query, embedding_model, vector_store_path="./vector_stores/tech_db")
        elif user_input.lower().startswith("general: "):
            query = user_input[9:]
            test_retriever_with_embeddings(
                query, embedding_model, vector_store_path="./vector_stores/general_db")
        elif user_input.lower().startswith("compare: "):
            query = user_input[9:]
            compare_retrievers_with_embeddings(query, embedding_model)
        else:
            test_retriever_with_embeddings(
                user_input, embedding_model)  # Default store