Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
rerank model
Browse files
- pages/Semantic_Search.py +66 -72
- semantic_search/query_rewrite.py +3 -87
pages/Semantic_Search.py
CHANGED
|
@@ -747,83 +747,77 @@ def render_answer(answer,index):
|
|
| 747 |
col_1, col_2,col_3 = st.columns([70,10,20])
|
| 748 |
i = 0
|
| 749 |
filter_out = 0
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
|
| 765 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 766 |
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
if("highlight" in ans and 'Keyword Search' in st.session_state.input_searchType):
|
| 773 |
-
test_strs = ans["highlight"]
|
| 774 |
-
tag = "em"
|
| 775 |
-
res__ = []
|
| 776 |
-
for test_str in test_strs:
|
| 777 |
-
start_idx = test_str.find("<" + tag + ">")
|
| 778 |
-
|
| 779 |
-
while start_idx != -1:
|
| 780 |
-
end_idx = test_str.find("</" + tag + ">", start_idx)
|
| 781 |
-
if end_idx == -1:
|
| 782 |
-
break
|
| 783 |
-
res__.append(test_str[start_idx+len(tag)+2:end_idx])
|
| 784 |
-
start_idx = test_str.find("<" + tag + ">", end_idx)
|
| 785 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 786 |
|
| 787 |
-
|
| 788 |
-
|
| 789 |
-
final_desc = "<p>"
|
| 790 |
-
|
| 791 |
-
for word in desc__:
|
| 792 |
-
if(re.sub('[^A-Za-z0-9]+', '', word) in res__):
|
| 793 |
-
final_desc += "<span style='color:#e28743;font-weight:bold'>"+word+"</span> "
|
| 794 |
-
else:
|
| 795 |
-
final_desc += word + " "
|
| 796 |
-
|
| 797 |
-
final_desc += "</p>"
|
| 798 |
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
|
| 806 |
-
|
| 807 |
-
|
| 808 |
-
|
| 809 |
-
|
| 810 |
-
|
| 811 |
-
|
| 812 |
-
|
| 813 |
-
|
| 814 |
-
|
| 815 |
-
|
| 816 |
-
|
| 817 |
-
|
| 818 |
-
if(st.session_state.input_evaluate == "enabled"):
|
| 819 |
-
with st.container(border = False):
|
| 820 |
-
if("relevant" in ans.keys()):
|
| 821 |
-
if(ans['relevant']==True):
|
| 822 |
-
st.write(":white_check_mark:")
|
| 823 |
-
else:
|
| 824 |
-
st.write(":x:")
|
| 825 |
|
| 826 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 827 |
|
| 828 |
with col_3:
|
| 829 |
if(index == len(st.session_state.questions)):
|
|
|
|
| 747 |
col_1, col_2,col_3 = st.columns([70,10,20])
|
| 748 |
i = 0
|
| 749 |
filter_out = 0
|
| 750 |
+
if len(answer) == 0:
|
| 751 |
+
st.write("No results found")
|
| 752 |
+
else:
|
| 753 |
+
for ans in answer:
|
| 754 |
+
if('b5/b5319e00' in ans['image_url'] ):
|
| 755 |
+
filter_out+=1
|
| 756 |
+
continue
|
| 757 |
+
format_ = ans['image_url'].split(".")[-1]
|
| 758 |
+
Image.MAX_IMAGE_PIXELS = 100000000
|
| 759 |
+
width = 500
|
| 760 |
+
height = 500
|
| 761 |
+
with col_1:
|
| 762 |
+
inner_col_1,inner_col_2 = st.columns([8,92])
|
| 763 |
+
with inner_col_2:
|
| 764 |
+
st.image(ans['image_url'].replace("/home/ec2-user/SageMaker/","/home/user/"))
|
| 765 |
+
|
| 766 |
+
if("highlight" in ans and 'Keyword Search' in st.session_state.input_searchType):
|
| 767 |
+
test_strs = ans["highlight"]
|
| 768 |
+
tag = "em"
|
| 769 |
+
res__ = []
|
| 770 |
+
for test_str in test_strs:
|
| 771 |
+
start_idx = test_str.find("<" + tag + ">")
|
| 772 |
+
|
| 773 |
+
while start_idx != -1:
|
| 774 |
+
end_idx = test_str.find("</" + tag + ">", start_idx)
|
| 775 |
+
if end_idx == -1:
|
| 776 |
+
break
|
| 777 |
+
res__.append(test_str[start_idx+len(tag)+2:end_idx])
|
| 778 |
+
start_idx = test_str.find("<" + tag + ">", end_idx)
|
| 779 |
|
| 780 |
+
|
| 781 |
+
desc__ = ans['desc'].split(" ")
|
| 782 |
+
|
| 783 |
+
final_desc = "<p>"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 784 |
|
| 785 |
+
for word in desc__:
|
| 786 |
+
if(re.sub('[^A-Za-z0-9]+', '', word) in res__):
|
| 787 |
+
final_desc += "<span style='color:#e28743;font-weight:bold'>"+word+"</span> "
|
| 788 |
+
else:
|
| 789 |
+
final_desc += word + " "
|
| 790 |
|
| 791 |
+
final_desc += "</p>"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 792 |
|
| 793 |
+
st.markdown(final_desc,unsafe_allow_html = True)
|
| 794 |
+
else:
|
| 795 |
+
st.write(ans['desc'])
|
| 796 |
+
if("sparse" in ans):
|
| 797 |
+
with st.expander("Expanded document:"):
|
| 798 |
+
sparse_ = dict(sorted(ans['sparse'].items(), key=lambda item: item[1],reverse=True))
|
| 799 |
+
filtered_sparse = dict()
|
| 800 |
+
for key in sparse_:
|
| 801 |
+
if(sparse_[key]>=1.0):
|
| 802 |
+
filtered_sparse[key] = round(sparse_[key], 2)
|
| 803 |
+
st.write(filtered_sparse)
|
| 804 |
+
with st.expander("Document Metadata:",expanded = False):
|
| 805 |
+
st.write(":green[default:]")
|
| 806 |
+
st.json({"category:":ans['category'],"price":str(ans['price']),"gender_affinity":ans['gender_affinity'],"style":ans['style']},expanded = True)
|
| 807 |
+
if("rekog" in ans):
|
| 808 |
+
st.write(":green[enriched:]")
|
| 809 |
+
st.json(ans['rekog'],expanded = True)
|
| 810 |
+
with inner_col_1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 811 |
|
| 812 |
+
if(st.session_state.input_evaluate == "enabled"):
|
| 813 |
+
with st.container(border = False):
|
| 814 |
+
if("relevant" in ans.keys()):
|
| 815 |
+
if(ans['relevant']==True):
|
| 816 |
+
st.write(":white_check_mark:")
|
| 817 |
+
else:
|
| 818 |
+
st.write(":x:")
|
| 819 |
+
|
| 820 |
+
i = i+1
|
| 821 |
|
| 822 |
with col_3:
|
| 823 |
if(index == len(st.session_state.questions)):
|
semantic_search/query_rewrite.py
CHANGED
|
@@ -252,16 +252,6 @@ def get_new_query_res(query):
|
|
| 252 |
query = st.session_state.input_rekog_label
|
| 253 |
if(st.session_state.input_is_rewrite_query == 'enabled'):
|
| 254 |
|
| 255 |
-
# query_struct = query_constructor.invoke(
|
| 256 |
-
# {
|
| 257 |
-
# "query": query
|
| 258 |
-
# }
|
| 259 |
-
# )
|
| 260 |
-
# print("***prompt****")
|
| 261 |
-
# print(prompt)
|
| 262 |
-
# print("******query_struct******")
|
| 263 |
-
# print(query_struct)
|
| 264 |
-
|
| 265 |
res = invoke_models.invoke_llm_model( prompt_.format(query=query,schema = schema) ,False)
|
| 266 |
inter_query = res[7:-3].replace('\\"',"'").replace("\n","")
|
| 267 |
print("inter_query")
|
|
@@ -294,43 +284,8 @@ def get_new_query_res(query):
|
|
| 294 |
draft_new_query['bool']['must'].append(q_dash)
|
| 295 |
else:
|
| 296 |
draft_new_query['bool']['should'].append(q_dash)
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
# q__dash = json.loads(json.dumps(q_).replace('term','match' ))
|
| 300 |
-
# clause = list(q__dash.keys())[0]category
|
| 301 |
-
# long_field = list(q__dash[clause].keys())[0]
|
| 302 |
-
# get_attr = long_field.split(".")[1]
|
| 303 |
-
# q__dash[clause][get_attr] = q__dash[clause][long_field]
|
| 304 |
-
# draft_new_query['bool']['should'].append(q__dash)
|
| 305 |
-
|
| 306 |
-
#print(draft_new_query)
|
| 307 |
-
query_ = draft_new_query#json.loads(json.dumps(opts.visit_structured_query(query_struct)[1]['filter']).replace("must","should"))#.replace("must","should")
|
| 308 |
-
|
| 309 |
-
# if('bool' in query_ and 'should' in query_['bool']):
|
| 310 |
-
# query_['bool']['should'].append({
|
| 311 |
-
# "match": {
|
| 312 |
-
|
| 313 |
-
# "rekog_description_plus_original_description": query
|
| 314 |
-
|
| 315 |
-
# }
|
| 316 |
-
# })
|
| 317 |
-
# else:
|
| 318 |
-
# query_['bool']['should'] = {
|
| 319 |
-
# "match": {
|
| 320 |
-
|
| 321 |
-
# "rekog_description_plus_original_description": query
|
| 322 |
-
|
| 323 |
-
# }
|
| 324 |
-
# }
|
| 325 |
-
|
| 326 |
-
# def find_by_key(data, target):
|
| 327 |
-
# for key, value in data.items():
|
| 328 |
-
# if isinstance(value, dict):
|
| 329 |
-
# yield from find_by_key(value, target)
|
| 330 |
-
# elif key == target:
|
| 331 |
-
# yield value
|
| 332 |
-
# for x in find_by_key(query_, "metadata.category.keyword"):
|
| 333 |
-
# imp_item = x
|
| 334 |
|
| 335 |
|
| 336 |
###### find the main subject of the query
|
|
@@ -405,46 +360,7 @@ def get_new_query_res(query):
|
|
| 405 |
|
| 406 |
st.session_state.input_rewritten_query = {"query":query_}
|
| 407 |
print(st.session_state.input_rewritten_query)
|
| 408 |
-
|
| 409 |
-
# amazon_rekognition.call(st.session_state.input_text,st.session_state.input_rekog_label)
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
# #return searchWithNewQuery(st.session_state.input_rewritten_query)
|
| 413 |
-
|
| 414 |
-
# def searchWithNewQuery(new_query):
|
| 415 |
-
# response = aos_client.search(
|
| 416 |
-
# body = new_query,
|
| 417 |
-
# index = "demo-retail-rekognition"#'self-query-rewrite-retail',
|
| 418 |
-
# #pipeline = 'RAG-Search-Pipeline'
|
| 419 |
-
# )
|
| 420 |
-
|
| 421 |
-
# hits = response['hits']['hits']
|
| 422 |
-
# print("rewrite-------------------------")
|
| 423 |
-
# arr = []
|
| 424 |
-
# for doc in hits:
|
| 425 |
-
# # if('b5/b5319e00' in doc['_source']['image_s3_url'] ):
|
| 426 |
-
# # filter_out +=1
|
| 427 |
-
# # continue
|
| 428 |
-
|
| 429 |
-
# res_ = {"desc":doc['_source']['text'],"image_url":doc['_source']['metadata']['image_s3_url']}
|
| 430 |
-
# if('highlight' in doc):
|
| 431 |
-
# res_['highlight'] = doc['highlight']['text']
|
| 432 |
-
# # if('caption_embedding' in doc['_source']):
|
| 433 |
-
# # res_['sparse'] = doc['_source']['caption_embedding']
|
| 434 |
-
# # if('query_sparse' in response_ and len(arr) ==0 ):
|
| 435 |
-
# # res_['query_sparse'] = response_["query_sparse"]
|
| 436 |
-
# res_['id'] = doc['_id']
|
| 437 |
-
# res_['score'] = doc['_score']
|
| 438 |
-
# res_['title'] = doc['_source']['text']
|
| 439 |
-
|
| 440 |
-
# arr.append(res_)
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
# return arr
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
|
| 449 |
|
| 450 |
|
|
|
|
| 252 |
query = st.session_state.input_rekog_label
|
| 253 |
if(st.session_state.input_is_rewrite_query == 'enabled'):
|
| 254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
res = invoke_models.invoke_llm_model( prompt_.format(query=query,schema = schema) ,False)
|
| 256 |
inter_query = res[7:-3].replace('\\"',"'").replace("\n","")
|
| 257 |
print("inter_query")
|
|
|
|
| 284 |
draft_new_query['bool']['must'].append(q_dash)
|
| 285 |
else:
|
| 286 |
draft_new_query['bool']['should'].append(q_dash)
|
| 287 |
+
|
| 288 |
+
query_ = draft_new_query
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
|
| 291 |
###### find the main subject of the query
|
|
|
|
| 360 |
|
| 361 |
st.session_state.input_rewritten_query = {"query":query_}
|
| 362 |
print(st.session_state.input_rewritten_query)
|
| 363 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
|
| 365 |
|
| 366 |
|