ppsingh commited on
Commit
5254f40
·
verified ·
1 Parent(s): 277e49a

Update app/retriever.py

Browse files
Files changed (1) hide show
  1. app/retriever.py +14 -6
app/retriever.py CHANGED
@@ -3,6 +3,8 @@ from qdrant_client.http import models as rest
3
  from langchain.schema import Document
4
  from langchain_community.cross_encoders import HuggingFaceCrossEncoder
5
  from langchain.retrievers.document_compressors import CrossEncoderReranker
 
 
6
  import logging
7
  import os
8
  from .utils import getconfig
@@ -220,15 +222,21 @@ def get_context(
220
  search_kwargs = {
221
  "model_name": config.get("embeddings", "MODEL_NAME")
222
  }
223
-
 
 
 
 
 
 
224
  # filter support for QdrantVectorStore
225
- if isinstance(vectorstore, QdrantVectorStore):
226
- filter_obj = create_filter(reports, sources, subtype, year)
227
- if filter_obj:
228
- search_kwargs["filter"] = filter_obj
229
 
230
  # Perform initial retrieval
231
- retrieved_docs = vectorstore.search(query, top_k, **search_kwargs)
232
 
233
  logging.info(f"Retrieved {len(retrieved_docs)} documents for query: {query[:50]}...")
234
 
 
3
  from langchain.schema import Document
4
  from langchain_community.cross_encoders import HuggingFaceCrossEncoder
5
  from langchain.retrievers.document_compressors import CrossEncoderReranker
6
+ from sentence_transformers import SentenceTransformer
7
+ model = SentenceTransformer('BAAI/bge-m3')
8
  import logging
9
  import os
10
  from .utils import getconfig
 
222
  search_kwargs = {
223
  "model_name": config.get("embeddings", "MODEL_NAME")
224
  }
225
+ model = SentenceTransformer(config.get("embeddings", "MODEL_NAME"))
226
+ query_vector = model.encode(query).tolist()
227
+ retrieved_docs = client.search(
228
+ collection_name="EUDR",
229
+ query_vector=query_vector,
230
+ limit=top_k,
231
+ with_payload=True)
232
  # filter support for QdrantVectorStore
233
+ #if isinstance(vectorstore, QdrantVectorStore):
234
+ # filter_obj = create_filter(reports, sources, subtype, year)
235
+ # if filter_obj:
236
+ # search_kwargs["filter"] = filter_obj
237
 
238
  # Perform initial retrieval
239
+ #retrieved_docs = vectorstore.search(query, top_k,)
240
 
241
  logging.info(f"Retrieved {len(retrieved_docs)} documents for query: {query[:50]}...")
242