search pipeline updated
- RAG/colpali.py +2 -22
- RAG/rag_DocumentSearcher.py +17 -2
RAG/colpali.py
CHANGED
@@ -206,39 +206,23 @@ def colpali_search_rerank(query):
         add_score = 0
 
         for index,i in enumerate(query_token_vectors):
-            #token = vocab_dict[str(token_ids[index])]
-            #if(token!='[SEP]' and token!='[CLS]'):
             query_token_vector = np.array(i)
-            #print("query token: "+token)
-            #print("-----------------")
             scores = []
             for m in with_s:
-                #m_arr = m.split("-")
-                #if(m_arr[-1]!='[SEP]' and m_arr[-1]!='[CLS]'):
-                #print("document token: "+m_arr[3])
                 doc_token_vector = np.array(m['page_sub_vector'])
                 score = np.dot(query_token_vector,doc_token_vector)
                 scores.append(score)
-
-
+
             scores.sort(reverse=True)
             max_score = scores[0]
             add_score+=max_score
-        #max_score_dict_list.append(newlist[0])
-        #print(newlist[0])
-        #max_score_dict_list_sorted = sorted(max_score_dict_list, key=lambda d: d['score'], reverse=True)
-        #print(max_score_dict_list_sorted)
-        # print(add_score)
         doc["total_score"] = add_score
-        #doc['max_score_dict_list_sorted'] = max_score_dict_list_sorted
         final_docs.append(doc)
     final_docs_sorted = sorted(final_docs, key=lambda d: d['total_score'], reverse=True)
     final_docs_sorted_20.append(final_docs_sorted[:20])
     img = "/home/user/app/vs/"+final_docs_sorted_20[0][0]['image']
     ans = generate_ans(img,query)
     images_highlighted = [{'file':img}]
-    # if(st.session_state.show_columns == True):
-    #     images_highlighted = img_highlight(img,query_token_vectors,result['query_tokens'])
     st.session_state.top_img = img
     st.session_state.query_token_vectors = query_token_vectors
     st.session_state.query_tokens = result['query_tokens']
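Note: the hunk above is a ColPali-style late-interaction (MaxSim) scorer with its commented-out debug prints removed. For each query token vector it takes the maximum dot product over a page's patch vectors (page_sub_vector) and sums those maxima into doc["total_score"]. A minimal vectorized sketch of the same scoring, assuming query_vecs and page_vecs hold the per-token and per-patch embeddings (names are hypothetical, not from the Space):

import numpy as np

def maxsim_score(query_vecs, page_vecs):
    # query_vecs: (n_query_tokens, dim); page_vecs: (n_patches, dim).
    # For each query token, keep its best-matching patch similarity,
    # then sum over tokens -- the same quantity the loop above
    # accumulates into doc["total_score"].
    sims = query_vecs @ page_vecs.T      # (n_query_tokens, n_patches)
    return sims.max(axis=1).sum()        # max over patches, sum over tokens

# Ranking pages the way final_docs_sorted does, assuming each page dict
# carries a hypothetical "vectors" array:
# ranked = sorted(pages, key=lambda p: maxsim_score(q, p["vectors"]), reverse=True)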
@@ -312,12 +296,8 @@ def img_highlight(img,batch_queries,query_tokens):
     # # Get the similarity map for our (only) input image
     similarity_maps = batched_similarity_maps[0]  # (query_length, n_patches_x, n_patches_y)
 
-    print(f"Similarity map shape: (query_length, n_patches_x, n_patches_y) = {tuple(similarity_maps.shape)}")
-    print(query_tokens)
     query_tokens_from_model = query_tokens[0]['tokens']
-
-    print(type(query_tokens_from_model))
-
+
     plots = plot_all_similarity_maps(
         image=image,
         query_tokens=query_tokens_from_model,
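The second hunk only trims debug prints from img_highlight, which turns the same token-level similarities into per-token heatmaps over the image patch grid. Conceptually, each map is one query token's similarity row reshaped to the patch grid; a numpy sketch of that idea (a simplification for illustration, not the plot_all_similarity_maps API):

import numpy as np

def token_similarity_map(query_token_vec, patch_vecs, n_patches_x, n_patches_y):
    # query_token_vec: (dim,); patch_vecs: (n_patches_x * n_patches_y, dim).
    # Returns an (n_patches_x, n_patches_y) map to overlay on the page image,
    # analogous to one slice of similarity_maps above.
    sims = patch_vecs @ query_token_vec          # similarity per patch
    return sims.reshape(n_patches_x, n_patches_y)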
RAG/rag_DocumentSearcher.py
CHANGED
@@ -189,23 +189,38 @@ def query_(awsauth,inputs, session_id,search_types):
     # query_sparse = sparse_["inference_results"][0]["output"][0]["dataAsMap"]["response"][0]
 
     hits = []
-    if(num_queries>1
+    if(num_queries>1):
         s_pipeline_url = host + s_pipeline_path
         r = requests.put(s_pipeline_url, auth=awsauth, json=s_pipeline_payload, headers=headers)
         path = st.session_state.input_index+"/_search?search_pipeline=rag-search-pipeline"
     else:
-
+        if(input_is_rerank):
+            path = st.session_state.input_index+"/_search?search_pipeline=rerank_pipeline_rag"
+        else:
+            path = st.session_state.input_index+"/_search"
     url = host+path
     if(len(hybrid_payload["query"]["hybrid"]["queries"])==1):
         single_query = hybrid_payload["query"]["hybrid"]["queries"][0]
         del hybrid_payload["query"]["hybrid"]
         hybrid_payload["query"] = single_query
+        if(st.session_state.input_is_rerank):
+            hybrid_payload["ext"] = {"rerank": {
+                "query_context": {
+                    "query_text": question
+                }
+            }}
         r = requests.get(url, auth=awsauth, json=hybrid_payload, headers=headers)
         response_ = json.loads(r.text)
         print(response_)
         hits = response_['hits']['hits']
 
     else:
+        if(st.session_state.input_is_rerank):
+            hybrid_payload["ext"] = {"rerank": {
+                "query_context": {
+                    "query_text": question
+                }
+            }}
         r = requests.get(url, auth=awsauth, json=hybrid_payload, headers=headers)
         response_ = json.loads(r.text)
         hits = response_['hits']['hits']
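The added branches wire single-query searches through an OpenSearch rerank search pipeline: the pipeline is selected with the search_pipeline query parameter, and the rerank response processor reads the query text from ext.rerank.query_context.query_text in the request body. A standalone sketch of that request shape, assuming a reachable cluster at host and an already-created rerank_pipeline_rag (endpoint, index, and field names are placeholders):

import json
import requests

host = "https://my-opensearch-domain"   # placeholder endpoint
index = "sample-index"                  # placeholder index
question = "what is the refund policy?"

payload = {
    "query": {"match": {"text": question}},   # placeholder field name
    # Consumed by the rerank processor in rerank_pipeline_rag:
    "ext": {"rerank": {"query_context": {"query_text": question}}},
}

url = host + "/" + index + "/_search?search_pipeline=rerank_pipeline_rag"
# Auth omitted for brevity; the Space passes auth=awsauth on every call.
r = requests.get(url, json=payload, headers={"Content-Type": "application/json"})
hits = json.loads(r.text)["hits"]["hits"]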
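In the num_queries>1 branch, the code first PUTs s_pipeline_payload to create or update rag-search-pipeline before running the hybrid query through it. That payload is not part of this diff; for reference, a typical OpenSearch hybrid-search pipeline definition uses a normalization-processor, as in the sketch below (an assumption based on OpenSearch's documented hybrid search setup, not the Space's actual s_pipeline_payload):

import requests

host = "https://my-opensearch-domain"   # placeholder endpoint
s_pipeline_payload = {
    "description": "Post-processor for hybrid search",
    "phase_results_processors": [{
        "normalization-processor": {
            # Rescales each sub-query's scores, then combines them.
            "normalization": {"technique": "min_max"},
            "combination": {"technique": "arithmetic_mean"},
        }
    }],
}

# Mirrors the requests.put(...) call in query_(); auth omitted for brevity.
r = requests.put(
    host + "/_search/pipeline/rag-search-pipeline",
    json=s_pipeline_payload,
    headers={"Content-Type": "application/json"},
)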